|
|
|
|
|
|
|
|
|
|
|
<html> |
|
<head> |
|
<meta charset="UTF-8"> |
|
<meta name="viewport" content="width=device-width, initial-scale=1"> |
|
<title>Chat with your PDF</title> |
|
<meta name="description" content="Chat with your PDF"> |
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@gradio/[email protected]/dist/lite.css" /> |
|
<style> |
|
html, body { |
|
margin: 0; |
|
padding: 0; |
|
height: 100%; |
|
background: var(--body-background-fill); |
|
} |
|
|
|
footer { |
|
display: none !important; |
|
} |
|
|
|
#chatbot { |
|
height: auto !important; |
|
min-height: 500px; |
|
} |
|
|
|
.chatbot { |
|
white-space: pre-wrap; |
|
} |
|
|
|
.gallery-item > .gallery { |
|
max-width: 380px; |
|
} |
|
</style> |
|
</head> |
|
<body> |
|
<gradio-lite> |
|
<gradio-requirements> |
|
pdfminer.six==20231228 |
|
pyodide-http==0.2.1 |
|
</gradio-requirements> |
|
|
|
<gradio-file name="chat_history.json"> |
|
[[null, "ようこそ! PDFのテキストを参照しながら対話できるチャットボットです。\nPDFファイルをアップロードするとテキストが抽出されます。\nメッセージの中に{context}と書くと、抽出されたテキストがその部分に埋め込まれて対話が行われます。一番下のExamplesにその例があります。\nメッセージを書くときにShift+Enterを入力すると改行できます。"]] |
|
</gradio-file> |
|
|
|
<gradio-file name="app.py" entrypoint> |
|
import os |
|
|
|
# Gradioによるアナリティクスを無効化 |
|
os.putenv("GRADIO_ANALYTICS_ENABLED", "False") |
|
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" |
|
|
|
import gradio as gr |
|
import base64 |
|
from pathlib import Path |
|
import json |
|
|
|
import pyodide_http |
|
pyodide_http.patch_all() |
|
|
|
from pdfminer.pdfinterp import PDFResourceManager |
|
from pdfminer.converter import TextConverter |
|
from pdfminer.pdfinterp import PDFPageInterpreter |
|
from pdfminer.pdfpage import PDFPage |
|
from pdfminer.layout import LAParams |
|
from io import StringIO |
|
|
|
# openaiライブラリのインストール方法は https://github.com/pyodide/pyodide/issues/4292 を参考にしました。 |
|
import micropip |
|
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/multidict/multidict-4.7.6-py3-none-any.whl", keep_going=True) |
|
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/frozenlist/frozenlist-1.4.0-py3-none-any.whl", keep_going=True) |
|
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/aiohttp/aiohttp-4.0.0a2.dev0-py3-none-any.whl", keep_going=True) |
|
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/openai/openai-1.3.7-py3-none-any.whl", keep_going=True) |
|
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/urllib3/urllib3-2.1.0-py3-none-any.whl", keep_going=True) |
|
await micropip.install("ssl") |
|
import ssl |
|
await micropip.install("httpx", keep_going=True) |
|
import httpx |
|
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/urllib3/urllib3-2.1.0-py3-none-any.whl", keep_going=True) |
|
import urllib3 |
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) |
|
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.1-cp311-cp311-emscripten_3_1_45_wasm32.whl", keep_going=True) |
|
|
|
|
|
class URLLib3Transport(httpx.BaseTransport): |
|
""" |
|
urllib3を使用してhttpxのリクエストを処理するカスタムトランスポートクラス |
|
""" |
|
def __init__(self): |
|
self.pool = urllib3.PoolManager() |
|
|
|
def handle_request(self, request: httpx.Request): |
|
payload = json.loads(request.content.decode("utf-8")) |
|
urllib3_response = self.pool.request(request.method, str(request.url), headers=request.headers, json=payload) |
|
stream = httpx.ByteStream(urllib3_response.data) |
|
return httpx.Response(urllib3_response.status, headers=urllib3_response.headers, stream=stream) |
|
|
|
http_client = httpx.Client(transport=URLLib3Transport()) |
|
|
|
from openai import OpenAI, AzureOpenAI |
|
import tiktoken |
|
|
|
|
|
OPENAI_TOKENIZER = tiktoken.get_encoding("cl100k_base") |
|
|
|
|
|
def extract_text(pdf_filename): |
|
""" |
|
PDFファイルからテキストを抽出する。 |
|
|
|
Args: |
|
pdf_filename (str): 抽出するPDFファイルのパス |
|
|
|
Returns: |
|
str: PDFファイルから抽出されたテキスト |
|
""" |
|
with open(pdf_filename, "rb") as pdf_file: |
|
output = StringIO() |
|
resource_manager = PDFResourceManager() |
|
laparams = LAParams() |
|
text_converter = TextConverter(resource_manager, output, laparams=laparams) |
|
page_interpreter = PDFPageInterpreter(resource_manager, text_converter) |
|
|
|
for i_page in PDFPage.get_pages(pdf_file): |
|
try: |
|
page_interpreter.process_page(i_page) |
|
except Exception as e: |
|
# print(e) |
|
pass |
|
|
|
output_text = output.getvalue() |
|
output.close() |
|
text_converter.close() |
|
return output_text |
|
|
|
|
|
def get_character_count_info(char_count, token_count): |
|
""" |
|
文字数とトークン数の情報を文字列で返す。 |
|
|
|
Args: |
|
char_count (int): 文字数 |
|
token_count (int): トークン数 |
|
|
|
Returns: |
|
str: 文字数とトークン数の情報を含む文字列 |
|
""" |
|
return f"""{char_count:,} character{'s' if char_count > 1 else ''} |
|
{token_count:,} token{'s' if token_count > 1 else ''}""" |
|
|
|
|
|
def update_context_element(pdf_file_obj): |
|
""" |
|
PDFファイルからテキストを抽出し、コンテキスト要素を更新する。 |
|
|
|
Args: |
|
pdf_file_obj (File): アップロードされたPDFファイルオブジェクト |
|
|
|
Returns: |
|
Tuple: コンテキストテキストボックスに格納する抽出されたテキスト情報と、その文字数情報 |
|
""" |
|
context = extract_text(pdf_file_obj.name) |
|
return gr.update(value=context, interactive=True), count_characters(context) |
|
|
|
|
|
def count_characters(text): |
|
""" |
|
テキストの文字数とトークン数を計算する。 |
|
|
|
Args: |
|
text (str): 文字数とトークン数を計算するテキスト |
|
|
|
Returns: |
|
str: 文字数とトークン数の情報を含む文字列 |
|
""" |
|
tokens = OPENAI_TOKENIZER.encode(text) |
|
return get_character_count_info(len(text), len(tokens)) |
|
|
|
|
|
def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature): |
|
""" |
|
ユーザーのプロンプトを処理し、ChatGPTによる生成結果を返す。 |
|
|
|
Args: |
|
prompt (str): ユーザーからの入力プロンプト |
|
history (list): チャット履歴 |
|
context (str): チャットコンテキスト |
|
platform (str): 使用するAIプラットフォーム |
|
endpoint (str): AIサービスのエンドポイント |
|
azure_deployment (str): Azureのデプロイメント名 |
|
azure_api_version (str): Azure APIのバージョン |
|
api_key (str): APIキー |
|
model_name (str): 使用するAIモデルの名前 |
|
max_tokens (int): 生成する最大トークン数 |
|
temperature (float): クリエイティビティの度合いを示す温度パラメータ |
|
|
|
Returns: |
|
str: ChatGPTによる生成結果 |
|
""" |
|
try: |
|
messages = [] |
|
for user_message, assistant_message in history: |
|
if user_message is not None and assistant_message is not None: |
|
messages.append({ "role": "user", "content": user_message }) |
|
messages.append({ "role": "assistant", "content": assistant_message }) |
|
|
|
prompt = prompt.replace("{context}", context) |
|
messages.append({ "role": "user", "content": prompt }) |
|
|
|
if platform == "OpenAI": |
|
openai_client = OpenAI( |
|
base_url=endpoint, |
|
api_key=api_key, |
|
http_client=http_client |
|
) |
|
else: # Azure |
|
openai_client = AzureOpenAI( |
|
azure_endpoint=endpoint, |
|
api_version=azure_api_version, |
|
azure_deployment=azure_deployment, |
|
api_key=api_key, |
|
http_client=http_client |
|
) |
|
|
|
completion = openai_client.chat.completions.create( |
|
messages=messages, |
|
model=model_name, |
|
max_tokens=max_tokens, |
|
temperature=temperature, |
|
stream=False |
|
) |
|
|
|
if hasattr(completion, "error"): |
|
raise gr.Error(completion.error["message"]) |
|
else: |
|
message = completion.choices[0].message |
|
return message.content |
|
|
|
except Exception as e: |
|
if hasattr(e, "message"): |
|
raise gr.Error(e.message) |
|
else: |
|
raise gr.Error(str(e)) |
|
|
|
|
|
def load_api_key(file_obj): |
|
""" |
|
APIキーファイルからAPIキーを読み込む。 |
|
|
|
Args: |
|
file_obj (File): APIキーファイルオブジェクト |
|
|
|
Returns: |
|
str: 読み込まれたAPIキー文字列 |
|
""" |
|
try: |
|
with open(file_obj.name, "r", encoding="utf-8") as api_key_file: |
|
return api_key_file.read().strip() |
|
except Exception as e: |
|
raise gr.Error(str(e)) |
|
|
|
|
|
def main(): |
|
""" |
|
アプリケーションのメイン関数。Gradioインターフェースを設定し、アプリケーションを起動する。 |
|
""" |
|
try: |
|
# クエリパラメータに保存されていることもあるチャット履歴を読み出す。 |
|
with open("chat_history.json", "r", encoding="utf-8") as f: |
|
CHAT_HISTORY = json.load(f) |
|
except Exception as e: |
|
print(e) |
|
CHAT_HISTORY = [] |
|
|
|
# localStorageから設定情報ををロードする。 |
|
js_define_utilities_and_load_settings = """() => { |
|
const KEY_PREFIX = "serverless_chat_with_your_pdf:"; |
|
|
|
const loadSettings = () => { |
|
const getItem = (key, defaultValue) => { |
|
const jsonValue = localStorage.getItem(KEY_PREFIX + key); |
|
if (jsonValue) { |
|
return JSON.parse(jsonValue); |
|
} else { |
|
return defaultValue; |
|
} |
|
}; |
|
|
|
const platform = getItem("platform", "OpenAI"); |
|
const endpoint = getItem("endpoint", "https://api.openai.com/v1"); |
|
const azure_deployment = getItem("azure_deployment", ""); |
|
const azure_api_version = getItem("azure_api_version", ""); |
|
const model_name = getItem("model_name", "gpt-4-turbo-preview"); |
|
const max_tokens = getItem("max_tokens", 1024); |
|
const temperature = getItem("temperature", 0.2); |
|
const save_chat_history_to_url = getItem("save_chat_history_to_url", false); |
|
|
|
return [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url]; |
|
}; |
|
|
|
globalThis.resetSettings = () => { |
|
for (let key in localStorage) { |
|
if (key.startsWith(KEY_PREFIX)) { |
|
localStorage.removeItem(key); |
|
} |
|
} |
|
|
|
return loadSettings(); |
|
}; |
|
|
|
globalThis.saveItem = (key, value) => { |
|
localStorage.setItem(KEY_PREFIX + key, JSON.stringify(value)); |
|
}; |
|
|
|
return loadSettings(); |
|
} |
|
""" |
|
|
|
# should_saveがtrueであればURLにチャット履歴を保存し、falseであればチャット履歴を削除する。 |
|
save_or_delete_chat_history = '''(hist, should_save) => { |
|
saveItem("save_chat_history_to_url", should_save); |
|
if (!should_save) { |
|
const url = new URL(window.location.href); |
|
url.searchParams.delete("history"); |
|
window.history.replaceState({path:url.href}, '', url.href); |
|
} else { |
|
const compressedHistory = LZString.compressToEncodedURIComponent(JSON.stringify(hist)); |
|
const url = new URL(window.location.href); |
|
url.searchParams.set("history", compressedHistory); |
|
window.history.replaceState({path:url.href}, '', url.href); |
|
} |
|
}''' |
|
|
|
with gr.Blocks(theme=gr.themes.Default(), analytics_enabled=False) as app: |
|
with gr.Tabs(): |
|
with gr.TabItem("Settings"): |
|
with gr.Row(): |
|
with gr.Column(): |
|
platform = gr.Radio(label="Platform", interactive=True, |
|
choices=["OpenAI", "Azure"], value="OpenAI") |
|
platform.change(None, inputs=platform, outputs=None, |
|
js='(x) => saveItem("platform", x)', show_progress="hidden") |
|
|
|
with gr.Row(): |
|
endpoint = gr.Textbox(label="Endpoint", interactive=True) |
|
endpoint.change(None, inputs=endpoint, outputs=None, |
|
js='(x) => saveItem("endpoint", x)', show_progress="hidden") |
|
|
|
azure_deployment = gr.Textbox(label="Azure Deployment", interactive=True) |
|
azure_deployment.change(None, inputs=azure_deployment, outputs=None, |
|
js='(x) => saveItem("azure_deployment", x)', show_progress="hidden") |
|
|
|
azure_api_version = gr.Textbox(label="Azure API Version", interactive=True) |
|
azure_api_version.change(None, inputs=azure_api_version, outputs=None, |
|
js='(x) => saveItem("azure_api_version", x)', show_progress="hidden") |
|
|
|
with gr.Row(): |
|
api_key_file = gr.File(file_count="single", file_types=["text"], |
|
height=80, label="API Key File") |
|
api_key = gr.Textbox(label="API Key", type="password", interactive=True) |
|
# 注意: 秘密情報をlocalStorageに保存してはならない。他者に秘密情報が盗まれる危険性があるからである。 |
|
|
|
api_key_file.upload(fn=load_api_key, inputs=api_key_file, outputs=api_key, |
|
show_progress="hidden") |
|
api_key_file.clear(fn=lambda: None, inputs=None, outputs=api_key, show_progress="hidden") |
|
|
|
model_name = gr.Textbox(label="model", interactive=True) |
|
model_name.change(None, inputs=model_name, outputs=None, |
|
js='(x) => saveItem("model_name", x)', show_progress="hidden") |
|
|
|
max_tokens = gr.Number(label="Max Tokens", interactive=True, |
|
minimum=0, precision=0, step=1) |
|
max_tokens.change(None, inputs=max_tokens, outputs=None, |
|
js='(x) => saveItem("max_tokens", x)', show_progress="hidden") |
|
|
|
temperature = gr.Slider(label="Temperature", interactive=True, |
|
minimum=0.0, maximum=1.0, step=0.1) |
|
temperature.change(None, inputs=temperature, outputs=None, |
|
js='(x) => saveItem("temperature", x)', show_progress="hidden") |
|
|
|
save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True) |
|
|
|
setting_items = [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens, temperature, save_chat_history_to_url] |
|
reset_button = gr.Button("Reset Settings") |
|
reset_button.click(None, inputs=None, outputs=setting_items, |
|
js="() => resetSettings()", show_progress="hidden") |
|
|
|
with gr.TabItem("Chat"): |
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
pdf_file = gr.File(file_count="single", file_types=[".pdf"], |
|
height=80, label="PDF") |
|
context = gr.Textbox(label="Context", lines=20, |
|
interactive=True, autoscroll=False, show_copy_button=True) |
|
char_counter = gr.Textbox(label="Statistics", value=get_character_count_info(0, 0), |
|
lines=2, max_lines=2, interactive=False, container=True) |
|
|
|
pdf_file.upload(fn=update_context_element, inputs=pdf_file, outputs=[context, char_counter]) |
|
pdf_file.clear(fn=lambda: None, inputs=None, outputs=context, show_progress="hidden") |
|
|
|
context.change(fn=count_characters, inputs=context, outputs=char_counter, show_progress="hidden") |
|
|
|
with gr.Column(scale=2): |
|
chatbot = gr.Chatbot( |
|
CHAT_HISTORY, |
|
elem_id="chatbot", render=False, height=500, show_copy_button=True, |
|
render_markdown=False, likeable=False, layout="bubble", |
|
avatar_images=[None, Path("robot.png")]) |
|
|
|
chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None, |
|
# チャット履歴をクエリパラメータに保存する。 |
|
js=save_or_delete_chat_history, show_progress="hidden") |
|
|
|
save_chat_history_to_url.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None, |
|
js=save_or_delete_chat_history, show_progress="hidden") |
|
|
|
chat = gr.ChatInterface(process_prompt, |
|
title="Chat with your PDF", |
|
chatbot=chatbot, |
|
textbox=gr.Textbox( |
|
placeholder="Type a message...", |
|
render=False, container=False, scale=7), |
|
additional_inputs=[context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature], |
|
examples=[['''制約条件に従い、以下の研究論文で提案されている技術や手法について要約してください。 |
|
|
|
# 制約条件 |
|
* 要約者: 大学教授 |
|
* 想定読者: 大学院生 |
|
* 要約結果の言語: 日本語 |
|
* 要約結果の構成: |
|
1. どんな研究であるか |
|
2. 先行研究に比べて優れている点は何か |
|
3. 提案されている技術や手法の重要な点は何か |
|
4. どのような方法で有効であると評価したか |
|
5. 何か議論はあるか |
|
6. 次に読むべき論文は何か |
|
|
|
# 研究論文 |
|
""" |
|
{context} |
|
""" |
|
|
|
# 要約結果'''], ['''制約条件に従い、以下の文書の内容を要約してください。 |
|
|
|
# 制約条件 |
|
* 要約者: 大学教授 |
|
* 想定読者: 大学院生 |
|
* 形式: 箇条書き |
|
* 分量: 20項目 |
|
* 要約結果の言語: 日本語 |
|
|
|
# 文書 |
|
""" |
|
{context} |
|
""" |
|
|
|
# 要約'''], ['''制約条件に従い、以下の文書から情報を抽出してください。 |
|
|
|
# 制約条件 |
|
* 抽出する情報: 課題や問題点について言及している全ての文。一つも見落とさないでください。 |
|
* 出力形式: 箇条書き |
|
* 出力言語: 元の言語の文章と、その日本語訳 |
|
|
|
# 文書 |
|
""" |
|
{context} |
|
""" |
|
|
|
# 抽出結果'''], ["続きを生成してください。"]]) |
|
|
|
app.load(None, inputs=None, outputs=setting_items, |
|
js=js_define_utilities_and_load_settings, show_progress="hidden") |
|
|
|
app.queue().launch() |
|
|
|
main() |
|
</gradio-file> |
|
|
|
|
|
<gradio-file name="robot.png" url="https://raw.githubusercontent.com/sonoisa/misc/main/resources/icons/chatbot_icon.png" /> |
|
</gradio-lite> |
|
|
|
<script language="javascript" src="https://cdn.jsdelivr.net/npm/[email protected]/libs/lz-string.min.js"></script> |
|
<script language="javascript"> |
|
(function () { |
|
|
|
const url = new URL(window.location.href); |
|
|
|
if (url.searchParams.has("history")) { |
|
const compressedHistory = url.searchParams.get("history"); |
|
hist = LZString.decompressFromEncodedURIComponent(compressedHistory); |
|
|
|
const chat_history_element = document.querySelector('gradio-file[name="chat_history.json"]'); |
|
chat_history_element.textContent = hist; |
|
} |
|
})(); |
|
</script> |
|
<script type="module" crossorigin src="https://cdn.jsdelivr.net/npm/@gradio/[email protected]/dist/lite.js"></script> |
|
</body> |
|
</html> |