sonoisa commited on
Commit
346ff58
·
verified ·
1 Parent(s): 1396cfc

Add RAG functionality

Browse files
Files changed (1) hide show
  1. index.html +437 -107
index.html CHANGED
@@ -34,6 +34,10 @@ https://opensource.org/license/mit/
34
  .gallery-item > .gallery {
35
  max-width: 380px;
36
  }
 
 
 
 
37
  </style>
38
  </head>
39
  <body>
@@ -41,10 +45,12 @@ https://opensource.org/license/mit/
41
  <gradio-requirements>
42
  pdfminer.six==20231228
43
  pyodide-http==0.2.1
 
 
44
  </gradio-requirements>
45
 
46
  <gradio-file name="chat_history.json">
47
- [[null, "ようこそ! PDFのテキストを参照しながら対話できるチャットボットです。\nPDFファイルをアップロードするとテキストが抽出されます。\nメッセージの中に{context}と書くと、抽出されたテキストがその部分に埋め込まれて対話が行われます。一番下のExamplesにその例があります。\nメッセージを書くときにShift+Enterを入力すると改行できます。"]]
48
  </gradio-file>
49
 
50
  <gradio-file name="app.py" entrypoint>
@@ -54,21 +60,6 @@ import os
54
  os.putenv("GRADIO_ANALYTICS_ENABLED", "False")
55
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
56
 
57
- import gradio as gr
58
- import base64
59
- from pathlib import Path
60
- import json
61
-
62
- import pyodide_http
63
- pyodide_http.patch_all()
64
-
65
- from pdfminer.pdfinterp import PDFResourceManager
66
- from pdfminer.converter import TextConverter
67
- from pdfminer.pdfinterp import PDFPageInterpreter
68
- from pdfminer.pdfpage import PDFPage
69
- from pdfminer.layout import LAParams
70
- from io import StringIO
71
-
72
  # openaiライブラリのインストール方法は https://github.com/pyodide/pyodide/issues/4292 を参考にしました。
73
  import micropip
74
  await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/multidict/multidict-4.7.6-py3-none-any.whl", keep_going=True)
@@ -86,6 +77,34 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
86
  await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.1-cp311-cp311-emscripten_3_1_45_wasm32.whl", keep_going=True)
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  class URLLib3Transport(httpx.BaseTransport):
90
  """
91
  urllib3を使用してhttpxのリクエストを処理するカスタムトランスポートクラス
@@ -101,14 +120,23 @@ class URLLib3Transport(httpx.BaseTransport):
101
 
102
  http_client = httpx.Client(transport=URLLib3Transport())
103
 
104
- from openai import OpenAI, AzureOpenAI
105
- import tiktoken
 
 
 
 
 
 
106
 
107
 
108
  OPENAI_TOKENIZER = tiktoken.get_encoding("cl100k_base")
 
 
 
109
 
110
 
111
- def extract_text(pdf_filename):
112
  """
113
  PDFファイルからテキストを抽出する。
114
 
@@ -116,8 +144,10 @@ def extract_text(pdf_filename):
116
  pdf_filename (str): 抽出するPDFファイルのパス
117
 
118
  Returns:
119
- str: PDFファイルから抽出されたテキスト
120
  """
 
 
121
  with open(pdf_filename, "rb") as pdf_file:
122
  output = StringIO()
123
  resource_manager = PDFResourceManager()
@@ -125,32 +155,79 @@ def extract_text(pdf_filename):
125
  text_converter = TextConverter(resource_manager, output, laparams=laparams)
126
  page_interpreter = PDFPageInterpreter(resource_manager, text_converter)
127
 
 
128
  for i_page in PDFPage.get_pages(pdf_file):
129
  try:
 
130
  page_interpreter.process_page(i_page)
 
 
 
 
 
131
  except Exception as e:
132
- # print(e)
133
  pass
134
 
135
- output_text = output.getvalue()
136
  output.close()
137
  text_converter.close()
138
- return output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
 
141
- def get_character_count_info(char_count, token_count):
 
 
 
 
142
  """
143
  文字数とトークン数の情報を文字列で返す。
144
 
145
  Args:
146
- char_count (int): 文字数
147
- token_count (int): トークン数
148
 
149
  Returns:
150
  str: 文字数とトークン数の情報を含む文字列
151
  """
152
- return f"""{char_count:,} character{'s' if char_count > 1 else ''}
153
- {token_count:,} token{'s' if token_count > 1 else ''}"""
 
154
 
155
 
156
  def update_context_element(pdf_file_obj):
@@ -163,25 +240,82 @@ def update_context_element(pdf_file_obj):
163
  Returns:
164
  Tuple: コンテキストテキストボックスに格納する抽出されたテキスト情報と、その文字数情報
165
  """
166
- context = extract_text(pdf_file_obj.name)
167
- return gr.update(value=context, interactive=True), count_characters(context)
 
168
 
169
 
170
- def count_characters(text):
171
  """
172
  テキストの文字数とトークン数を計算する。
 
173
 
174
  Args:
175
- text (str): 文字数とトークン数を計算するテキスト
176
 
177
  Returns:
178
  str: 文字数とトークン数の情報を含む文字列
179
  """
 
 
 
180
  tokens = OPENAI_TOKENIZER.encode(text)
181
- return get_character_count_info(len(text), len(tokens))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
 
 
 
 
183
 
184
- def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  """
186
  ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。
187
 
@@ -201,6 +335,59 @@ def process_prompt(prompt, history, context, platform, endpoint, azure_deploymen
201
  Returns:
202
  str: ChatGPTによる生成結果
203
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  try:
205
  messages = []
206
  for user_message, assistant_message in history:
@@ -227,20 +414,98 @@ def process_prompt(prompt, history, context, platform, endpoint, azure_deploymen
227
  http_client=http_client
228
  )
229
 
230
- bot_response = ""
231
  completion = openai_client.chat.completions.create(
232
  messages=messages,
233
  model=model_name,
234
  max_tokens=max_tokens,
235
  temperature=temperature,
 
 
236
  stream=False
237
  )
238
 
 
239
  if hasattr(completion, "error"):
240
  raise gr.Error(completion.error["message"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  else:
242
- message = completion.choices[0].message
243
- bot_response += message.content
244
  yield bot_response
245
 
246
  except Exception as e:
@@ -280,11 +545,11 @@ def main():
280
  CHAT_HISTORY = []
281
 
282
  # localStorageから設定情報ををロードする。
283
- js_define_utilities_and_load_settings = """() => {
284
  const KEY_PREFIX = "serverless_chat_with_your_pdf:";
285
 
286
- const loadSettings = () => {
287
- const getItem = (key, defaultValue) => {
288
  const jsonValue = localStorage.getItem(KEY_PREFIX + key);
289
  if (jsonValue) {
290
  return JSON.parse(jsonValue);
@@ -305,7 +570,7 @@ def main():
305
  return [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url];
306
  };
307
 
308
- globalThis.resetSettings = () => {
309
  for (let key in localStorage) {
310
  if (key.startsWith(KEY_PREFIX)) {
311
  localStorage.removeItem(key);
@@ -315,7 +580,7 @@ def main():
315
  return loadSettings();
316
  };
317
 
318
- globalThis.saveItem = (key, value) => {
319
  localStorage.setItem(KEY_PREFIX + key, JSON.stringify(value));
320
  };
321
 
@@ -324,7 +589,7 @@ def main():
324
  """
325
 
326
  # should_saveがtrueであればURLにチャット履歴を保存し、falseであればチャット履歴を削除する。
327
- save_or_delete_chat_history = '''(hist, should_save) => {
328
  saveItem("save_chat_history_to_url", should_save);
329
  if (!should_save) {
330
  const url = new URL(window.location.href);
@@ -338,6 +603,108 @@ def main():
338
  }
339
  }'''
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  with gr.Blocks(theme=gr.themes.Default(), analytics_enabled=False) as app:
342
  with gr.Tabs():
343
  with gr.TabItem("Settings"):
@@ -346,20 +713,20 @@ def main():
346
  platform = gr.Radio(label="Platform", interactive=True,
347
  choices=["OpenAI", "Azure"], value="OpenAI")
348
  platform.change(None, inputs=platform, outputs=None,
349
- js='(x) => saveItem("platform", x)', show_progress="hidden")
350
 
351
  with gr.Row():
352
  endpoint = gr.Textbox(label="Endpoint", interactive=True)
353
  endpoint.change(None, inputs=endpoint, outputs=None,
354
- js='(x) => saveItem("endpoint", x)', show_progress="hidden")
355
 
356
  azure_deployment = gr.Textbox(label="Azure Deployment", interactive=True)
357
  azure_deployment.change(None, inputs=azure_deployment, outputs=None,
358
- js='(x) => saveItem("azure_deployment", x)', show_progress="hidden")
359
 
360
  azure_api_version = gr.Textbox(label="Azure API Version", interactive=True)
361
  azure_api_version.change(None, inputs=azure_api_version, outputs=None,
362
- js='(x) => saveItem("azure_api_version", x)', show_progress="hidden")
363
 
364
  with gr.Row():
365
  api_key_file = gr.File(file_count="single", file_types=["text"],
@@ -367,45 +734,47 @@ def main():
367
  api_key = gr.Textbox(label="API Key", type="password", interactive=True)
368
  # 注意: 秘密情報をlocalStorageに保存してはならない。他者に秘密情報が盗まれる危険性があるからである。
369
 
370
- api_key_file.upload(fn=load_api_key, inputs=api_key_file, outputs=api_key,
371
  show_progress="hidden")
372
- api_key_file.clear(fn=lambda: None, inputs=None, outputs=api_key, show_progress="hidden")
373
 
374
  model_name = gr.Textbox(label="model", interactive=True)
375
  model_name.change(None, inputs=model_name, outputs=None,
376
- js='(x) => saveItem("model_name", x)', show_progress="hidden")
377
 
378
  max_tokens = gr.Number(label="Max Tokens", interactive=True,
379
  minimum=0, precision=0, step=1)
380
  max_tokens.change(None, inputs=max_tokens, outputs=None,
381
- js='(x) => saveItem("max_tokens", x)', show_progress="hidden")
382
 
383
  temperature = gr.Slider(label="Temperature", interactive=True,
384
  minimum=0.0, maximum=1.0, step=0.1)
385
  temperature.change(None, inputs=temperature, outputs=None,
386
- js='(x) => saveItem("temperature", x)', show_progress="hidden")
387
 
388
  save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
389
 
390
- setting_items = [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url]
 
391
  reset_button = gr.Button("Reset Settings")
392
  reset_button.click(None, inputs=None, outputs=setting_items,
393
- js="() => resetSettings()", show_progress="hidden")
394
 
395
  with gr.TabItem("Chat"):
396
  with gr.Row():
397
  with gr.Column(scale=1):
398
  pdf_file = gr.File(file_count="single", file_types=[".pdf"],
399
  height=80, label="PDF")
400
- context = gr.Textbox(label="Context", lines=20,
401
  interactive=True, autoscroll=False, show_copy_button=True)
402
- char_counter = gr.Textbox(label="Statistics", value=get_character_count_info(0, 0),
403
  lines=2, max_lines=2, interactive=False, container=True)
404
 
405
- pdf_file.upload(fn=update_context_element, inputs=pdf_file, outputs=[context, char_counter])
406
- pdf_file.clear(fn=lambda: None, inputs=None, outputs=context, show_progress="hidden")
407
 
408
- context.change(fn=count_characters, inputs=context, outputs=char_counter, show_progress="hidden")
 
409
 
410
  with gr.Column(scale=2):
411
  chatbot = gr.Chatbot(
@@ -415,7 +784,7 @@ def main():
415
  avatar_images=[None, Path("robot.png")])
416
 
417
  chat_message_textbox = gr.Textbox(placeholder="Type a message...",
418
- render=False, container=False, scale=7)
419
 
420
  chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
421
  # チャット履歴をクエリパラメータに保存する。
@@ -428,53 +797,14 @@ def main():
428
  title="Chat with your PDF",
429
  chatbot=chatbot,
430
  textbox=chat_message_textbox,
431
- additional_inputs=[context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature],
432
- examples=[['''制約条件に従い、以下の研究論文で提案されている技術や手法について要約してください。
433
-
434
- # 制約条件
435
- * 要約者: 大学教授
436
- * 想定読者: 大学院生
437
- * 要約結果の言語: 日本語
438
- * 要約結果の構成(以下の各項目について500文字):
439
- 1. どんな研究であるか
440
- 2. 先行研究に比べて優れている点は何か
441
- 3. 提案されている技術や手法の重要な点は何か
442
- 4. どのような方法で有効であると評価したか
443
- 5. 何か議論はあるか
444
- 6. 次に読むべき論文は何か
445
-
446
- # 研究論文
447
- """
448
- {context}
449
- """
450
-
451
- # 要約結果'''], ['''制約条件に従い、以下の文書の内容を要約してください。
452
-
453
- # 制約条件
454
- * 要約者: 大学教授
455
- * 想定読者: 大学院生
456
- * 形式: 箇条書き
457
- * 分量: 20項目
458
- * 要約結果の言語: 日本語
459
-
460
- # 文書
461
- """
462
- {context}
463
- """
464
-
465
- # 要約'''], ['''制約条件に従い、以下の文書から情報を抽出してください。
466
-
467
- # 制約条件
468
- * 抽出する情報: 課題や問題点について言及している全ての文。一つも見落とさないでください。
469
- * 出力形式: 箇条書き
470
- * 出力言語: 元の言語の文章と、その日本語訳
471
-
472
- # 文書
473
- """
474
- {context}
475
- """
476
-
477
- # 抽出結果'''], ["続きを生成してください。"]])
478
 
479
  app.load(None, inputs=None, outputs=setting_items,
480
  js=js_define_utilities_and_load_settings, show_progress="hidden")
 
34
  .gallery-item > .gallery {
35
  max-width: 380px;
36
  }
37
+
38
+ #context > label > textarea {
39
+ scrollbar-width: thin !important;
40
+ }
41
  </style>
42
  </head>
43
  <body>
 
45
  <gradio-requirements>
46
  pdfminer.six==20231228
47
  pyodide-http==0.2.1
48
+ janome==0.5.0
49
+ rank_bm25==0.2.2
50
  </gradio-requirements>
51
 
52
  <gradio-file name="chat_history.json">
53
+ [[null, "ようこそ! PDFのテキストを参照しながら対話できるチャットボットです。\nPDFファイルをアップロードするとテキストが抽出されます。\nメッセージの中に{context}と書くと、抽出されたテキストがその部分に埋め込まれて対話が行われます。他にもPDFのページを検索して参照したり、ページ番号を指定して参照したりすることができます。一番下のExamplesにこれらの例があります。\nメッセージを書くときにShift+Enterを入力すると改行できます。"]]
54
  </gradio-file>
55
 
56
  <gradio-file name="app.py" entrypoint>
 
60
  os.putenv("GRADIO_ANALYTICS_ENABLED", "False")
61
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # openaiライブラリのインストール方法は https://github.com/pyodide/pyodide/issues/4292 を参考にしました。
64
  import micropip
65
  await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/multidict/multidict-4.7.6-py3-none-any.whl", keep_going=True)
 
77
  await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.1-cp311-cp311-emscripten_3_1_45_wasm32.whl", keep_going=True)
78
 
79
 
80
+ import gradio as gr
81
+ import base64
82
+ import json
83
+ import unicodedata
84
+ import re
85
+ from pathlib import Path
86
+ from dataclasses import dataclass
87
+ import asyncio
88
+
89
+ import pyodide_http
90
+ pyodide_http.patch_all()
91
+
92
+ from pdfminer.pdfinterp import PDFResourceManager
93
+ from pdfminer.converter import TextConverter
94
+ from pdfminer.pdfinterp import PDFPageInterpreter
95
+ from pdfminer.pdfpage import PDFPage
96
+ from pdfminer.layout import LAParams
97
+ from io import StringIO
98
+
99
+ from janome.tokenizer import Tokenizer as JanomeTokenizer
100
+ from janome.analyzer import Analyzer as JanomeAnalyzer
101
+ from janome.tokenfilter import POSStopFilter, LowerCaseFilter
102
+ from rank_bm25 import BM25Okapi
103
+
104
+ from openai import OpenAI, AzureOpenAI
105
+ import tiktoken
106
+
107
+
108
  class URLLib3Transport(httpx.BaseTransport):
109
  """
110
  urllib3を使用してhttpxのリクエストを処理するカスタムトランスポートクラス
 
120
 
121
  http_client = httpx.Client(transport=URLLib3Transport())
122
 
123
+
124
+ @dataclass
125
+ class Page:
126
+ """
127
+ PDFのページ内容
128
+ """
129
+ number: int
130
+ content: str
131
 
132
 
133
  OPENAI_TOKENIZER = tiktoken.get_encoding("cl100k_base")
134
+ JANOME_TOKENIZER = JanomeTokenizer()
135
+ JANOME_ANALYZER = JanomeAnalyzer(tokenizer=JANOME_TOKENIZER,
136
+ token_filters=[POSStopFilter(["記号,空白"]), LowerCaseFilter()])
137
 
138
 
139
+ def extract_pdf_pages(pdf_filename):
140
  """
141
  PDFファイルからテキストを抽出する。
142
 
 
144
  pdf_filename (str): 抽出するPDFファイルのパス
145
 
146
  Returns:
147
+ list[Page]: PDFの各ページ内容のリスト
148
  """
149
+
150
+ pages = []
151
  with open(pdf_filename, "rb") as pdf_file:
152
  output = StringIO()
153
  resource_manager = PDFResourceManager()
 
155
  text_converter = TextConverter(resource_manager, output, laparams=laparams)
156
  page_interpreter = PDFPageInterpreter(resource_manager, text_converter)
157
 
158
+ page_number = 0
159
  for i_page in PDFPage.get_pages(pdf_file):
160
  try:
161
+ page_number += 1
162
  page_interpreter.process_page(i_page)
163
+ page_content = output.getvalue()
164
+ page_content = unicodedata.normalize('NFKC', page_content)
165
+ pages.append(Page(number=page_number, content=page_content))
166
+ output.truncate(0)
167
+ output.seek(0)
168
  except Exception as e:
169
+ print(e)
170
  pass
171
 
 
172
  output.close()
173
  text_converter.close()
174
+
175
+ return pages
176
+
177
+
178
+ def merge_pages_with_page_tag(pages):
179
+ """
180
+ PDFの各ページ内容を一つの文字列にマージする。
181
+ ただし、chatpdf:pageというタグでページを括る。
182
+ extract_pages_from_page_tag()の逆変換である。
183
+
184
+ Args:
185
+ pages (list[Page]): PDFの各ページ内容のリスト
186
+
187
+ Returns:
188
+ str: PDFの各ページ内容をマージした文字列
189
+ """
190
+ document_with_page_tag = ""
191
+ for page in pages:
192
+ document_with_page_tag += f'&lt;chatpdf:page number="{page.number}"&gt;\n{page.content}\n&lt;/chatpdf:page&gt;\n'
193
+ return document_with_page_tag
194
+
195
+
196
+ def extract_pages_from_page_tag(document_with_page_tag):
197
+ """
198
+ chatpdf:pageというタグで括られた領域をPDFのページ内容と解釈して、Pageオブジェクトのリストに変換する。
199
+ merge_pages_with_page_tag()の逆変換である。
200
+
201
+ Args:
202
+ document_with_page_tag (str): chatpdf:pageというタグで各ページが括られた文字列
203
+
204
+ Returns:
205
+ list[Page]: Pageオブジェクトのリスト
206
+ """
207
+ page_tag_pattern = r'&lt;chatpdf:page number="(\d+)"&gt;\n?(.*?)\n?&lt;\/chatpdf:page&gt;\n?'
208
+ matches = re.findall(page_tag_pattern, document_with_page_tag, re.DOTALL)
209
+ pages = [Page(number=int(number), content=content) for number, content in matches]
210
+ return pages
211
 
212
 
213
+ def add_s(values):
214
+ return "s" if len(values) > 1 else ""
215
+
216
+
217
+ def get_context_info(characters, tokens):
218
  """
219
  文字数とトークン数の情報を文字列で返す。
220
 
221
  Args:
222
+ characters (str): テキスト
223
+ tokens (list[str]): トークン
224
 
225
  Returns:
226
  str: 文字数とトークン数の情報を含む文字列
227
  """
228
+ char_count = len(characters)
229
+ token_count = len(tokens)
230
+ return f"{char_count:,} character{add_s(characters)}\n{token_count:,} token{add_s(tokens)}"
231
 
232
 
233
  def update_context_element(pdf_file_obj):
 
240
  Returns:
241
  Tuple: コンテキストテキストボックスに格納する抽出されたテキスト情報と、その文字数情報
242
  """
243
+ pages = extract_pdf_pages(pdf_file_obj.name)
244
+ document_with_tag = merge_pages_with_page_tag(pages)
245
+ return gr.update(value=document_with_tag, interactive=True), count_characters(document_with_tag)
246
 
247
 
248
+ def count_characters(document_with_tag):
249
  """
250
  テキストの文字数とトークン数を計算する。
251
+ ただし、テキストはchatpdf:pageというタグでページが括られているとする。
252
 
253
  Args:
254
+ document_with_tag (str): 文字数とトークン数を計算するテキスト
255
 
256
  Returns:
257
  str: 文字数とトークン数の情報を含む文字列
258
  """
259
+
260
+ text = "".join([page.content for page in extract_pages_from_page_tag(document_with_tag)])
261
+
262
  tokens = OPENAI_TOKENIZER.encode(text)
263
+ return get_context_info(text, tokens)
264
+
265
+
266
+ class SearchEngine:
267
+ def __init__(self, engine, pages):
268
+ self.engine = engine
269
+ self.pages = pages
270
+
271
+
272
+ SEARCH_ENGINE = None
273
+
274
+ def create_search_engine(context):
275
+ global SEARCH_ENGINE
276
+
277
+ pages = extract_pages_from_page_tag(context)
278
+ tokenized_pages = []
279
+ original_pages = []
280
+ for page in pages:
281
+ page_content = page.content.strip()
282
+ if page_content:
283
+ tokenized_page = [token.base_form for token in JANOME_ANALYZER.analyze(page_content)]
284
+ if tokenized_page:
285
+ tokenized_pages.append(tokenized_page)
286
+ original_pages.append(page)
287
+
288
+ if tokenized_pages:
289
+ bm25 = BM25Okapi(tokenized_pages)
290
+ SEARCH_ENGINE = SearchEngine(engine=bm25, pages=original_pages)
291
+ else:
292
+ SEARCH_ENGINE = None
293
+
294
 
295
+ def search_pages(keywords, page_limit):
296
+ global SEARCH_ENGINE
297
+ if SEARCH_ENGINE is None:
298
+ return []
299
 
300
+ tokenized_query = [token.base_form for token in JANOME_ANALYZER.analyze(keywords)]
301
+ if not tokenized_query:
302
+ return []
303
+
304
+ found_pages = SEARCH_ENGINE.engine.get_top_n(tokenized_query, SEARCH_ENGINE.pages, n=page_limit)
305
+ return found_pages
306
+
307
+
308
+ def load_pages(page_numbers):
309
+ global SEARCH_ENGINE
310
+ if SEARCH_ENGINE is None:
311
+ return []
312
+
313
+ page_numbers = set(page_numbers)
314
+ found_pages = [page for page in SEARCH_ENGINE.pages if page.number in page_numbers]
315
+ return found_pages
316
+
317
+
318
+ async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature):
319
  """
320
  ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。
321
 
 
335
  Returns:
336
  str: ChatGPTによる生成結果
337
  """
338
+
339
+ pages = extract_pages_from_page_tag(context)
340
+ if pages:
341
+ context = "".join([page.content for page in pages])
342
+
343
+ tools = [
344
+ # ページ検索
345
+ {
346
+ "type": "function",
347
+ "function": {
348
+ "name": "search_pages",
349
+ "description": "与えられたキーワードを含むページを検索します。",
350
+ "parameters": {
351
+ "type": "object",
352
+ "properties": {
353
+ "keywords": {
354
+ "type": "string",
355
+ "description": "半角空白区切りの複数の検索キーワード, 例: Artificial General Intelligence 自律エージェント"
356
+ },
357
+ "page_limit": {
358
+ "type": "number",
359
+ "description": "検索するページ数, 例: 3",
360
+ "minimum": 1
361
+ }
362
+ }
363
+ },
364
+ "required": ["keywords"]
365
+ }
366
+ },
367
+ # ページ取得
368
+ {
369
+ "type": "function",
370
+ "function": {
371
+ "name": "load_pages",
372
+ "description": "与えられたページ番号のページを取得します。",
373
+ "parameters": {
374
+ "type": "object",
375
+ "properties": {
376
+ "page_numbers": {
377
+ "type": "array",
378
+ "items": {
379
+ "type": "number"
380
+ },
381
+ "minItems": 1,
382
+ "description": "取得するページのページ番号のリスト"
383
+ }
384
+ }
385
+ },
386
+ "required": ["page_numbers"]
387
+ }
388
+ }
389
+ ]
390
+
391
  try:
392
  messages = []
393
  for user_message, assistant_message in history:
 
414
  http_client=http_client
415
  )
416
 
 
417
  completion = openai_client.chat.completions.create(
418
  messages=messages,
419
  model=model_name,
420
  max_tokens=max_tokens,
421
  temperature=temperature,
422
+ tools=tools,
423
+ tool_choice="auto",
424
  stream=False
425
  )
426
 
427
+ bot_response = ""
428
  if hasattr(completion, "error"):
429
  raise gr.Error(completion.error["message"])
430
+
431
+ response_message = completion.choices[0].message
432
+ tool_calls = response_message.tool_calls
433
+ if tool_calls:
434
+ messages.append(response_message)
435
+
436
+ for tool_call in tool_calls:
437
+ function_name = tool_call.function.name
438
+ function_args = json.loads(tool_call.function.arguments)
439
+ if function_name == "search_pages":
440
+ # ページ検索
441
+ keywords = function_args.get("keywords").strip()
442
+ page_limit = function_args.get("page_limit") or 3
443
+
444
+ bot_response += f'Searching for pages containing the keyword{add_s(keywords.split(" "))} "{keywords}".\n'
445
+
446
+ found_pages = search_pages(keywords, page_limit)
447
+ function_response = json.dumps({
448
+ "status": "found" if found_pages else "not found",
449
+ "found_pages": [{
450
+ "page_number": page.number,
451
+ "page_content": page.content
452
+ } for page in found_pages]
453
+ }, ensure_ascii=False)
454
+ messages.append({
455
+ "tool_call_id": tool_call.id,
456
+ "role": "tool",
457
+ "name": function_name,
458
+ "content": function_response
459
+ })
460
+ if found_pages:
461
+ bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
462
+ else:
463
+ bot_response += "Page not found.\n\n"
464
+ elif function_name == "load_pages":
465
+ # ページ取得
466
+ page_numbers = function_args.get("page_numbers")
467
+
468
+ bot_response += f'Trying to load page{add_s(page_numbers)} {", ".join(map(str, page_numbers))}.\n'
469
+
470
+ found_pages = load_pages(page_numbers)
471
+ function_response = json.dumps({
472
+ "status": "found" if found_pages else "not found",
473
+ "found_pages": [{
474
+ "page_number": page.number,
475
+ "page_content": page.content
476
+ } for page in found_pages]
477
+ }, ensure_ascii=False)
478
+ messages.append({
479
+ "tool_call_id": tool_call.id,
480
+ "role": "tool",
481
+ "name": function_name,
482
+ "content": function_response
483
+ })
484
+ if found_pages:
485
+ bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
486
+ else:
487
+ bot_response += "Page not found.\n\n"
488
+
489
+ yield bot_response + "Generating response. Please wait a moment...\n"
490
+ await asyncio.sleep(0.1)
491
+
492
+ completion = openai_client.chat.completions.create(
493
+ messages=messages,
494
+ model=model_name,
495
+ max_tokens=max_tokens,
496
+ temperature=temperature,
497
+ stream=False
498
+ )
499
+
500
+ if hasattr(completion, "error"):
501
+ raise gr.Error(completion.error["message"])
502
+
503
+ response_message = completion.choices[0].message
504
+ bot_response += response_message.content
505
+ yield bot_response
506
+
507
  else:
508
+ bot_response += response_message.content
 
509
  yield bot_response
510
 
511
  except Exception as e:
 
545
  CHAT_HISTORY = []
546
 
547
  # localStorageから設定情報ををロードする。
548
+ js_define_utilities_and_load_settings = """() =&gt; {
549
  const KEY_PREFIX = "serverless_chat_with_your_pdf:";
550
 
551
+ const loadSettings = () =&gt; {
552
+ const getItem = (key, defaultValue) =&gt; {
553
  const jsonValue = localStorage.getItem(KEY_PREFIX + key);
554
  if (jsonValue) {
555
  return JSON.parse(jsonValue);
 
570
  return [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url];
571
  };
572
 
573
+ globalThis.resetSettings = () =&gt; {
574
  for (let key in localStorage) {
575
  if (key.startsWith(KEY_PREFIX)) {
576
  localStorage.removeItem(key);
 
580
  return loadSettings();
581
  };
582
 
583
+ globalThis.saveItem = (key, value) =&gt; {
584
  localStorage.setItem(KEY_PREFIX + key, JSON.stringify(value));
585
  };
586
 
 
589
  """
590
 
591
  # should_saveがtrueであればURLにチャット履歴を保存し、falseであればチャット履歴を削除する。
592
+ save_or_delete_chat_history = '''(hist, should_save) =&gt; {
593
  saveItem("save_chat_history_to_url", should_save);
594
  if (!should_save) {
595
  const url = new URL(window.location.href);
 
603
  }
604
  }'''
605
 
606
+ # プロンプト例
607
+ examples = {
608
+ "要約 (論文)": '''制約条件に従い、以下の研究論文で提案されている技術や手法について要約してください。
609
+
610
+ # 制約条件
611
+ * 要約者: 大学教授
612
+ * 想定読者: 大学院生
613
+ * 要約結果の言語: 日本語
614
+ * 要約結果の構成(以下の各項目について500文字):
615
+ 1. どんな研究であるか
616
+ 2. 先行研究に比べて優れている点は何か
617
+ 3. 提案されている技術や手法の重要な点は何か
618
+ 4. どのような方法で有効であると評価したか
619
+ 5. 何か議論はあるか
620
+ 6. 次に読むべき論文は何か
621
+
622
+ # 研究論文
623
+ """
624
+ {context}
625
+ """
626
+
627
+ # 要約結果''',
628
+ "要約 (一般)": '''制約条件に従い、以下の文書の内容を要約してください。
629
+
630
+ # 制約条件
631
+ * 要約者: 技術コンサルタント
632
+ * 想定読者: 経営層、CTO、CIO
633
+ * 形式: 箇条書き
634
+ * 分量: 20項目
635
+ * 要約結果の言語: 日本語
636
+
637
+ # 文書
638
+ """
639
+ {context}
640
+ """
641
+
642
+ # 要約''',
643
+ "情報抽出": '''制約条件に従い、以下の文書から情報を抽出してください。
644
+
645
+ # 制約条件
646
+ * 抽出する情報: 課題や問題点について言及している全ての文。一つも見落とさないでください。
647
+ * 出力形式: 箇条書き
648
+ * 出力言語: 元の言語の文章と、その日本語訳
649
+
650
+ # 文書
651
+ """
652
+ {context}
653
+ """
654
+
655
+ # 抽出結果''',
656
+ "QA (RAG)": '''次の質問に回答するために役立つページを検索して、その検索結果を使って回答して下さい。
657
+
658
+ # 制約条件
659
+ * 検索クエリの生成方法: 質問文の3つの言い換え(paraphrase)をカンマ区切りで連結した文字列
660
+ * 検索クエリの言語: 英語
661
+ * 検索するページ数: 3
662
+ * 回答方法:
663
+ - 検索結果の情報のみを用いて回答すること。
664
+ - 回答に利用した文章のあるページ番号を最後に出力すること。形式: "参考ページ番号: 71, 59, 47"
665
+ - 回答に役立つ情報が検索結果内にない場合は「検索結果には回答に役立つ情報がありませんでした。」と回答すること。
666
+ * 回答の言語: 日本語
667
+
668
+ # 質問
669
+ どのような方法で、提案された手法が有効であると評価しましたか?
670
+
671
+ # 回答''',
672
+ "要約 (RAG)": '''次のキーワードを含むページを検索して、その検索結果をページごとに要約して下さい。
673
+
674
+ # 制約条件
675
+ * キーワード: dataset datasets
676
+ * 検索するページ数: 3
677
+ * 要約結果の言語: 日本語
678
+ * 要約の形式:
679
+ ## ページ番号(例: 12ページ)
680
+ - 要約文1
681
+ - 要約文2
682
+ ...
683
+ * 要約の分量: 各ページ3項目
684
+
685
+ # 要約''',
686
+ "翻訳 (RAG)": '''次のキーワードを含むページを検索して、その検索結果を日本語に翻訳して下さい。
687
+
688
+ # 制約条件
689
+ * キーワード: dataset datasets
690
+ * 検索するページ数: 1
691
+
692
+ # 翻訳結果''',
693
+ "要約 (ページ指定)": '''16〜17ページをページごとに箇条書きで要約して下さい。
694
+
695
+ # 制約条件
696
+ * 要約結果の言語: 日本語
697
+ * 要約の形式:
698
+ ## ページ番号(例: 12ページ)
699
+ - 要約文1
700
+ - 要約文2
701
+ ...
702
+ * 要約の分量: 各ページ5項目
703
+
704
+ # 要約''',
705
+ "続きを生成": "続きを生成してください。"
706
+ }
707
+
708
  with gr.Blocks(theme=gr.themes.Default(), analytics_enabled=False) as app:
709
  with gr.Tabs():
710
  with gr.TabItem("Settings"):
 
713
  platform = gr.Radio(label="Platform", interactive=True,
714
  choices=["OpenAI", "Azure"], value="OpenAI")
715
  platform.change(None, inputs=platform, outputs=None,
716
+ js='(x) =&gt; saveItem("platform", x)', show_progress="hidden")
717
 
718
  with gr.Row():
719
  endpoint = gr.Textbox(label="Endpoint", interactive=True)
720
  endpoint.change(None, inputs=endpoint, outputs=None,
721
+ js='(x) =&gt; saveItem("endpoint", x)', show_progress="hidden")
722
 
723
  azure_deployment = gr.Textbox(label="Azure Deployment", interactive=True)
724
  azure_deployment.change(None, inputs=azure_deployment, outputs=None,
725
+ js='(x) =&gt; saveItem("azure_deployment", x)', show_progress="hidden")
726
 
727
  azure_api_version = gr.Textbox(label="Azure API Version", interactive=True)
728
  azure_api_version.change(None, inputs=azure_api_version, outputs=None,
729
+ js='(x) =&gt; saveItem("azure_api_version", x)', show_progress="hidden")
730
 
731
  with gr.Row():
732
  api_key_file = gr.File(file_count="single", file_types=["text"],
 
734
  api_key = gr.Textbox(label="API Key", type="password", interactive=True)
735
  # 注意: 秘密情報をlocalStorageに保存してはならない。他者に秘密情報が盗まれる危険性があるからである。
736
 
737
+ api_key_file.upload(load_api_key, inputs=api_key_file, outputs=api_key,
738
  show_progress="hidden")
739
+ api_key_file.clear(lambda: None, inputs=None, outputs=api_key, show_progress="hidden")
740
 
741
  model_name = gr.Textbox(label="model", interactive=True)
742
  model_name.change(None, inputs=model_name, outputs=None,
743
+ js='(x) =&gt; saveItem("model_name", x)', show_progress="hidden")
744
 
745
  max_tokens = gr.Number(label="Max Tokens", interactive=True,
746
  minimum=0, precision=0, step=1)
747
  max_tokens.change(None, inputs=max_tokens, outputs=None,
748
+ js='(x) =&gt; saveItem("max_tokens", x)', show_progress="hidden")
749
 
750
  temperature = gr.Slider(label="Temperature", interactive=True,
751
  minimum=0.0, maximum=1.0, step=0.1)
752
  temperature.change(None, inputs=temperature, outputs=None,
753
+ js='(x) =&gt; saveItem("temperature", x)', show_progress="hidden")
754
 
755
  save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
756
 
757
+ setting_items = [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens,
758
+ temperature, save_chat_history_to_url]
759
  reset_button = gr.Button("Reset Settings")
760
  reset_button.click(None, inputs=None, outputs=setting_items,
761
+ js="() =&gt; resetSettings()", show_progress="hidden")
762
 
763
  with gr.TabItem("Chat"):
764
  with gr.Row():
765
  with gr.Column(scale=1):
766
  pdf_file = gr.File(file_count="single", file_types=[".pdf"],
767
  height=80, label="PDF")
768
+ context = gr.Textbox(elem_id="context", label="Context", lines=20,
769
  interactive=True, autoscroll=False, show_copy_button=True)
770
+ char_counter = gr.Textbox(label="Statistics", value=get_context_info("", []),
771
  lines=2, max_lines=2, interactive=False, container=True)
772
 
773
+ pdf_file.upload(update_context_element, inputs=pdf_file, outputs=[context, char_counter])
774
+ pdf_file.clear(lambda: None, inputs=None, outputs=context, show_progress="hidden")
775
 
776
+ (context.change(count_characters, inputs=context, outputs=char_counter, show_progress="hidden")
777
+ .then(create_search_engine, inputs=context, outputs=None))
778
 
779
  with gr.Column(scale=2):
780
  chatbot = gr.Chatbot(
 
784
  avatar_images=[None, Path("robot.png")])
785
 
786
  chat_message_textbox = gr.Textbox(placeholder="Type a message...",
787
+ render=False, container=False, interactive=True, scale=7)
788
 
789
  chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
790
  # チャット履歴をクエリパラメータに保存する。
 
797
  title="Chat with your PDF",
798
  chatbot=chatbot,
799
  textbox=chat_message_textbox,
800
+ additional_inputs=[context, platform, endpoint, azure_deployment, azure_api_version, api_key,
801
+ model_name, max_tokens, temperature],
802
+ examples=None)
803
+
804
+ example_title_textbox = gr.Textbox(visible=False, interactive=True)
805
+ gr.Examples([[k] for k, v in examples.items()],
806
+ inputs=example_title_textbox, outputs=chat_message_textbox,
807
+ fn=lambda title: examples[title], run_on_click=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808
 
809
  app.load(None, inputs=None, outputs=setting_items,
810
  js=js_define_utilities_and_load_settings, show_progress="hidden")