sonoisa committed on
Commit
130ab22
·
verified ·
1 Parent(s): e5c29a5

Add cost estimation functionality

Browse files
Files changed (1) hide show
  1. index.html +215 -72
index.html CHANGED
@@ -38,6 +38,14 @@ https://opensource.org/license/mit/
38
  #context > label > textarea {
39
  scrollbar-width: thin !important;
40
  }
 
 
 
 
 
 
 
 
41
  </style>
42
  </head>
43
  <body>
@@ -77,6 +85,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
77
  await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.1-cp311-cp311-emscripten_3_1_45_wasm32.whl", keep_going=True)
78
 
79
 
 
80
  import gradio as gr
81
  import base64
82
  import json
@@ -400,6 +409,30 @@ CHAT_TOOLS = [
400
  }
401
  ]
402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
  async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature):
405
  """
@@ -421,21 +454,10 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
421
  Returns:
422
  str: ChatGPTによる生成結果
423
  """
424
-
425
- pages = extract_pages_from_page_tag(context)
426
- if pages:
427
- context = "".join([page.content for page in pages])
428
 
429
  try:
430
- messages = []
431
- for user_message, assistant_message in history:
432
- if user_message is not None and assistant_message is not None:
433
- user_message = user_message.replace("{context}", context)
434
- messages.append({ "role": "user", "content": user_message })
435
- messages.append({ "role": "assistant", "content": assistant_message })
436
-
437
- prompt = prompt.replace("{context}", context)
438
- messages.append({ "role": "user", "content": prompt })
439
 
440
  if platform == "OpenAI":
441
  openai_client = OpenAI(
@@ -468,6 +490,9 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
468
 
469
  response_message = completion.choices[0].message
470
  tool_calls = response_message.tool_calls
 
 
 
471
  if tool_calls:
472
  messages.append(response_message)
473
 
@@ -499,6 +524,7 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
499
  bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
500
  else:
501
  bot_response += "Page not found.\n\n"
 
502
  elif function_name == "load_pages":
503
  # ページ取得
504
  page_numbers = function_args.get("page_numbers")
@@ -523,6 +549,8 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
523
  bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
524
  else:
525
  bot_response += "Page not found.\n\n"
 
 
526
 
527
  yield bot_response + "Generating response. Please wait a moment...\n"
528
  await asyncio.sleep(0.1)
@@ -534,6 +562,8 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
534
  temperature=temperature,
535
  stream=False
536
  )
 
 
537
 
538
  if hasattr(completion, "error"):
539
  raise gr.Error(completion.error["message"])
@@ -553,6 +583,7 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
553
  raise gr.Error(str(e))
554
 
555
 
 
556
  def load_api_key(file_obj):
557
  """
558
  APIキーファイルからAPIキーを読み込む。
@@ -570,6 +601,10 @@ def load_api_key(file_obj):
570
  raise gr.Error(str(e))
571
 
572
 
 
 
 
 
573
  def main():
574
  """
575
  アプリケーションのメイン関数。Gradioインターフェースを設定し、アプリケーションを起動する。
@@ -762,26 +797,26 @@ def main():
762
  with gr.Blocks(theme=gr.themes.Default(), analytics_enabled=False) as app:
763
  with gr.Tabs():
764
  with gr.TabItem("Settings"):
765
- with gr.Row():
766
- with gr.Column():
767
- platform = gr.Radio(label="Platform", interactive=True,
768
- choices=["OpenAI", "Azure"], value="OpenAI")
769
- platform.change(None, inputs=platform, outputs=None,
770
- js='(x) =&gt; saveItem("platform", x)', show_progress="hidden")
771
-
772
- with gr.Row():
773
- endpoint = gr.Textbox(label="Endpoint", interactive=True)
774
- endpoint.change(None, inputs=endpoint, outputs=None,
775
- js='(x) =&gt; saveItem("endpoint", x)', show_progress="hidden")
776
-
777
- azure_deployment = gr.Textbox(label="Azure Deployment", interactive=True)
778
- azure_deployment.change(None, inputs=azure_deployment, outputs=None,
779
- js='(x) =&gt; saveItem("azure_deployment", x)', show_progress="hidden")
780
-
781
- azure_api_version = gr.Textbox(label="Azure API Version", interactive=True)
782
- azure_api_version.change(None, inputs=azure_api_version, outputs=None,
783
- js='(x) =&gt; saveItem("azure_api_version", x)', show_progress="hidden")
784
-
785
  with gr.Row():
786
  api_key_file = gr.File(file_count="single", file_types=["text"],
787
  height=80, label="API Key File")
@@ -792,27 +827,27 @@ def main():
792
  show_progress="hidden")
793
  api_key_file.clear(lambda: None, inputs=None, outputs=api_key, show_progress="hidden")
794
 
795
- model_name = gr.Textbox(label="model", interactive=True)
796
- model_name.change(None, inputs=model_name, outputs=None,
797
- js='(x) =&gt; saveItem("model_name", x)', show_progress="hidden")
798
 
799
- max_tokens = gr.Number(label="Max Tokens", interactive=True,
800
- minimum=0, precision=0, step=1)
801
- max_tokens.change(None, inputs=max_tokens, outputs=None,
802
- js='(x) =&gt; saveItem("max_tokens", x)', show_progress="hidden")
803
 
804
- temperature = gr.Slider(label="Temperature", interactive=True,
805
- minimum=0.0, maximum=1.0, step=0.1)
806
- temperature.change(None, inputs=temperature, outputs=None,
807
- js='(x) =&gt; saveItem("temperature", x)', show_progress="hidden")
808
 
809
- save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
810
 
811
- setting_items = [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens,
812
- temperature, save_chat_history_to_url]
813
- reset_button = gr.Button("Reset Settings")
814
- reset_button.click(None, inputs=None, outputs=setting_items,
815
- js="() =&gt; resetSettings()", show_progress="hidden")
816
 
817
  with gr.TabItem("Chat"):
818
  with gr.Row():
@@ -827,18 +862,131 @@ def main():
827
  pdf_file.upload(update_context_element, inputs=pdf_file, outputs=[context, char_counter])
828
  pdf_file.clear(lambda: None, inputs=None, outputs=context, show_progress="hidden")
829
 
830
- (context.change(count_characters, inputs=context, outputs=char_counter, show_progress="hidden")
831
- .then(create_search_engine, inputs=context, outputs=None))
832
-
833
  with gr.Column(scale=2):
834
- chatbot = gr.Chatbot(
835
- CHAT_HISTORY,
836
- elem_id="chatbot", render=False, height=500, show_copy_button=True,
837
- sanitize_html=False, render_markdown=False, likeable=False, layout="bubble",
838
- avatar_images=[None, Path("robot.png")])
839
 
840
- chat_message_textbox = gr.Textbox(placeholder="Type a message...",
841
- render=False, container=False, interactive=True, scale=7)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
842
 
843
  chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
844
  # チャット履歴をクエリパラメータに保存する。
@@ -847,18 +995,13 @@ def main():
847
  save_chat_history_to_url.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
848
  js=save_or_delete_chat_history, show_progress="hidden")
849
 
850
- chat = gr.ChatInterface(process_prompt,
851
- title="Chat with your PDF",
852
- chatbot=chatbot,
853
- textbox=chat_message_textbox,
854
- additional_inputs=[context, platform, endpoint, azure_deployment, azure_api_version, api_key,
855
- model_name, max_tokens, temperature],
856
- examples=None)
857
-
858
- example_title_textbox = gr.Textbox(visible=False, interactive=True)
859
- gr.Examples([[k] for k, v in examples.items()],
860
- inputs=example_title_textbox, outputs=chat_message_textbox,
861
- fn=lambda title: examples[title], run_on_click=True)
862
 
863
  app.load(None, inputs=None, outputs=setting_items,
864
  js=js_define_utilities_and_load_settings, show_progress="hidden")
 
38
  #context > label > textarea {
39
  scrollbar-width: thin !important;
40
  }
41
+
42
+ #cost_info {
43
+ border-style: none !important;
44
+ }
45
+
46
+ #cost_info > label > input {
47
+ background: var(--panel-background-fill) !important;
48
+ }
49
  </style>
50
  </head>
51
  <body>
 
85
  await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.1-cp311-cp311-emscripten_3_1_45_wasm32.whl", keep_going=True)
86
 
87
 
88
+ import inspect
89
  import gradio as gr
90
  import base64
91
  import json
 
409
  }
410
  ]
411
 
412
+ CHAT_TOOLS_TOKENS = 139
413
+
414
+
415
+ def get_openai_messages(prompt, history, context):
416
+ global SEARCH_ENGINE
417
+ if SEARCH_ENGINE is not None:
418
+ context = "".join([page.content for page in SEARCH_ENGINE.pages])
419
+
420
+ messages = []
421
+ for user_message, assistant_message in history:
422
+ if user_message is not None and assistant_message is not None:
423
+ user_message = user_message.replace("{context}", context)
424
+ messages.append({ "role": "user", "content": user_message })
425
+ messages.append({ "role": "assistant", "content": assistant_message })
426
+
427
+ prompt = prompt.replace("{context}", context)
428
+ messages.append({ "role": "user", "content": prompt })
429
+
430
+ return messages
431
+
432
+
433
+ actual_total_cost_prompt = 0
434
+ actual_total_cost_completion = 0
435
+
436
 
437
  async def process_prompt(prompt, history, context, platform, endpoint, azure_deployment, azure_api_version, api_key, model_name, max_tokens, temperature):
438
  """
 
454
  Returns:
455
  str: ChatGPTによる生成結果
456
  """
457
+ global actual_total_cost_prompt, actual_total_cost_completion
 
 
 
458
 
459
  try:
460
+ messages = get_openai_messages(prompt, history, context)
 
 
 
 
 
 
 
 
461
 
462
  if platform == "OpenAI":
463
  openai_client = OpenAI(
 
490
 
491
  response_message = completion.choices[0].message
492
  tool_calls = response_message.tool_calls
493
+ actual_total_cost_prompt += completion.usage.prompt_tokens
494
+ actual_total_cost_completion += completion.usage.completion_tokens
495
+
496
  if tool_calls:
497
  messages.append(response_message)
498
 
 
524
  bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
525
  else:
526
  bot_response += "Page not found.\n\n"
527
+
528
  elif function_name == "load_pages":
529
  # ページ取得
530
  page_numbers = function_args.get("page_numbers")
 
549
  bot_response += f'Found page{add_s(found_pages)}: {", ".join([str(page.number) for page in found_pages])}.\n\n'
550
  else:
551
  bot_response += "Page not found.\n\n"
552
+ else:
553
+ raise gr.Error(f"Unknown function calling '{function_name}'.")
554
 
555
  yield bot_response + "Generating response. Please wait a moment...\n"
556
  await asyncio.sleep(0.1)
 
562
  temperature=temperature,
563
  stream=False
564
  )
565
+ actual_total_cost_prompt += completion.usage.prompt_tokens
566
+ actual_total_cost_completion += completion.usage.completion_tokens
567
 
568
  if hasattr(completion, "error"):
569
  raise gr.Error(completion.error["message"])
 
583
  raise gr.Error(str(e))
584
 
585
 
586
+
587
  def load_api_key(file_obj):
588
  """
589
  APIキーファイルからAPIキーを読み込む。
 
601
  raise gr.Error(str(e))
602
 
603
 
604
+ def get_cost_info(prompt_token_count):
605
+ return f"Estimated input cost: {prompt_token_count + CHAT_TOOLS_TOKENS:,} tokens, Actual total input cost: {actual_total_cost_prompt:,} tokens, Actual total output cost: {actual_total_cost_completion:,} tokens"
606
+
607
+
608
  def main():
609
  """
610
  アプリケーションのメイン関数。Gradioインターフェースを設定し、アプリケーションを起動する。
 
797
  with gr.Blocks(theme=gr.themes.Default(), analytics_enabled=False) as app:
798
  with gr.Tabs():
799
  with gr.TabItem("Settings"):
800
+ with gr.Column():
801
+ platform = gr.Radio(label="Platform", interactive=True,
802
+ choices=["OpenAI", "Azure"], value="OpenAI")
803
+ platform.change(None, inputs=platform, outputs=None,
804
+ js='(x) =&gt; saveItem("platform", x)', show_progress="hidden")
805
+
806
+ with gr.Row():
807
+ endpoint = gr.Textbox(label="Endpoint", interactive=True)
808
+ endpoint.change(None, inputs=endpoint, outputs=None,
809
+ js='(x) =&gt; saveItem("endpoint", x)', show_progress="hidden")
810
+
811
+ azure_deployment = gr.Textbox(label="Azure Deployment", interactive=True)
812
+ azure_deployment.change(None, inputs=azure_deployment, outputs=None,
813
+ js='(x) =&gt; saveItem("azure_deployment", x)', show_progress="hidden")
814
+
815
+ azure_api_version = gr.Textbox(label="Azure API Version", interactive=True)
816
+ azure_api_version.change(None, inputs=azure_api_version, outputs=None,
817
+ js='(x) =&gt; saveItem("azure_api_version", x)', show_progress="hidden")
818
+
819
+ with gr.Group():
820
  with gr.Row():
821
  api_key_file = gr.File(file_count="single", file_types=["text"],
822
  height=80, label="API Key File")
 
827
  show_progress="hidden")
828
  api_key_file.clear(lambda: None, inputs=None, outputs=api_key, show_progress="hidden")
829
 
830
+ model_name = gr.Textbox(label="model", interactive=True)
831
+ model_name.change(None, inputs=model_name, outputs=None,
832
+ js='(x) =&gt; saveItem("model_name", x)', show_progress="hidden")
833
 
834
+ max_tokens = gr.Number(label="Max Tokens", interactive=True,
835
+ minimum=0, precision=0, step=1)
836
+ max_tokens.change(None, inputs=max_tokens, outputs=None,
837
+ js='(x) =&gt; saveItem("max_tokens", x)', show_progress="hidden")
838
 
839
+ temperature = gr.Slider(label="Temperature", interactive=True,
840
+ minimum=0.0, maximum=1.0, step=0.1)
841
+ temperature.change(None, inputs=temperature, outputs=None,
842
+ js='(x) =&gt; saveItem("temperature", x)', show_progress="hidden")
843
 
844
+ save_chat_history_to_url = gr.Checkbox(label="Save Chat History to URL", interactive=True)
845
 
846
+ setting_items = [platform, endpoint, azure_deployment, azure_api_version, model_name, max_tokens,
847
+ temperature, save_chat_history_to_url]
848
+ reset_button = gr.Button("Reset Settings")
849
+ reset_button.click(None, inputs=None, outputs=setting_items,
850
+ js="() =&gt; resetSettings()", show_progress="hidden")
851
 
852
  with gr.TabItem("Chat"):
853
  with gr.Row():
 
862
  pdf_file.upload(update_context_element, inputs=pdf_file, outputs=[context, char_counter])
863
  pdf_file.clear(lambda: None, inputs=None, outputs=context, show_progress="hidden")
864
 
 
 
 
865
  with gr.Column(scale=2):
 
 
 
 
 
866
 
867
+ additional_inputs = [context, platform, endpoint, azure_deployment, azure_api_version, api_key,
868
+ model_name, max_tokens, temperature]
869
+
870
+ with gr.Blocks() as chat:
871
+ gr.Markdown(f"# Chat with your PDF")
872
+
873
+ with gr.Column(variant="panel"):
874
+ chatbot = gr.Chatbot(
875
+ CHAT_HISTORY,
876
+ elem_id="chatbot", height=500, show_copy_button=True,
877
+ sanitize_html=False, render_markdown=False, likeable=False, layout="bubble",
878
+ avatar_images=[None, Path("robot.png")])
879
+
880
+ message_state = gr.State()
881
+ chatbot_state = gr.State(chatbot.value) if chatbot.value else gr.State([])
882
+
883
+ with gr.Group():
884
+ with gr.Row():
885
+ message_textbox = gr.Textbox(placeholder="Type a message...",
886
+ container=False, show_label=False, interactive=True, scale=7)
887
+
888
+ submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=150)
889
+ stop_button = gr.Button("Stop", variant="stop", visible=False, scale=1, min_width=150)
890
+
891
+ cost_info = gr.Textbox(elem_id="cost_info", value=get_cost_info(0),
892
+ lines=1, max_lines=1, interactive=False, container=False, elem_classes="cost_info")
893
+
894
+ with gr.Row():
895
+ retry_button = gr.Button("🔄 Retry", variant="secondary", size="sm")
896
+ undo_button = gr.Button("↩️ Undo", variant="secondary", size="sm")
897
+ clear_button = gr.Button("🗑️ Clear", variant="secondary", size="sm")
898
+
899
+ def estimate_message_cost(prompt, history, context):
900
+ token_count = 0
901
+ messages = get_openai_messages(prompt, history, context)
902
+ for message in messages:
903
+ tokens = OPENAI_TOKENIZER.encode(message["content"])
904
+ token_count += len(tokens)
905
+
906
+ return gr.update(value=get_cost_info(token_count))
907
+
908
+ message_textbox.change(estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden")
909
+
910
+ example_title_textbox = gr.Textbox(visible=False, interactive=True)
911
+ gr.Examples([[k] for k, v in examples.items()],
912
+ inputs=example_title_textbox, outputs=message_textbox,
913
+ fn=lambda title: examples[title], run_on_click=True)
914
+
915
+ def append_message_to_history(message, history):
916
+ history.append([message, None])
917
+ return history, history
918
+
919
+ def undo_chat(history):
920
+ if history:
921
+ message, _ = history.pop()
922
+ message = message or ""
923
+ else:
924
+ message = ""
925
+ return history, history, message
926
+
927
+ async def submit_message(message, history_with_input, *args):
928
+ history = history_with_input[:-1]
929
+ inputs = [message, history]
930
+ inputs.extend(args)
931
+
932
+ generator = process_prompt(*inputs)
933
+
934
+ try:
935
+ first_response = await gr.utils.async_iteration(generator)
936
+ update = history + [[message, first_response]]
937
+ yield update, update
938
+ except StopIteration:
939
+ update = history + [[message, None]]
940
+ yield update, update
941
+
942
+ async for response in generator:
943
+ update = history + [[message, response]]
944
+ yield update, update
945
+
946
+ submit_triggers = [message_textbox.submit, submit_button.click]
947
+
948
+ submit_event = gr.events.on(submit_triggers, lambda message: ("", message),
949
+ inputs=[message_textbox], outputs=[message_textbox, message_state], queue=False
950
+ ).then(
951
+ append_message_to_history, inputs=[message_state, chatbot_state], outputs=[chatbot, chatbot_state], queue=False
952
+ ).then(
953
+ submit_message, inputs=[message_state, chatbot_state] + additional_inputs, outputs=[chatbot, chatbot_state]
954
+ ).then(
955
+ estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
956
+ )
957
+
958
+ for submit_trigger in submit_triggers:
959
+ submit_trigger(lambda: (gr.update(visible=False), gr.update(visible=True)),
960
+ inputs=None, outputs=[submit_button, stop_button], queue=False)
961
+ submit_event.then(lambda: (gr.update(visible=True), gr.update(visible=False)),
962
+ inputs=None, outputs=[submit_button, stop_button], queue=False)
963
+
964
+ stop_button.click(None, inputs=None, outputs=None, cancels=submit_event)
965
+
966
+ retry_button.click(
967
+ undo_chat, inputs=[chatbot_state], outputs=[chatbot, chatbot_state, message_state], queue=False
968
+ ).then(
969
+ append_message_to_history, inputs=[message_state, chatbot_state], outputs=[chatbot, chatbot_state], queue=False
970
+ ).then(
971
+ submit_message, inputs=[message_state, chatbot_state] + additional_inputs, outputs=[chatbot, chatbot_state]
972
+ ).then(
973
+ estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
974
+ )
975
+
976
+ undo_button.click(
977
+ undo_chat, inputs=[chatbot_state], outputs=[chatbot, chatbot_state, message_state], queue=False
978
+ ).then(
979
+ lambda message: message, inputs=message_state, outputs=message_textbox, queue=False
980
+ ).then(
981
+ estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
982
+ )
983
+
984
+ clear_button.click(
985
+ lambda: ([], [], None), inputs=None, outputs=[chatbot, chatbot_state, message_state],
986
+ queue=False
987
+ ).then(
988
+ estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
989
+ )
990
 
991
  chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
992
  # チャット履歴をクエリパラメータに保存する。
 
995
  save_chat_history_to_url.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
996
  js=save_or_delete_chat_history, show_progress="hidden")
997
 
998
+ context.change(
999
+ count_characters, inputs=context, outputs=char_counter, show_progress="hidden"
1000
+ ).then(
1001
+ create_search_engine, inputs=context, outputs=None
1002
+ ).then(
1003
+ estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden"
1004
+ )
 
 
 
 
 
1005
 
1006
  app.load(None, inputs=None, outputs=setting_items,
1007
  js=js_define_utilities_and_load_settings, show_progress="hidden")