Spaces:
Build error
Build error
Ubuntu
commited on
Commit
•
84e90e4
1
Parent(s):
b9009d9
Initial Commit
Browse files- api_calls.py +42 -0
- app.py +189 -0
- requirements.txt +3 -0
api_calls.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
|
3 |
+
API_ENDPOINT = "http://54.254.230.28:8888"
|
4 |
+
|
5 |
+
# function to call api
|
6 |
+
def call_api_stream(api_path, api_params):
|
7 |
+
session = requests.Session()
|
8 |
+
url = f"{API_ENDPOINT}/{api_path}"
|
9 |
+
response = session.post(
|
10 |
+
url, json=api_params, headers={"Content-Type": "application/json"},
|
11 |
+
stream=True
|
12 |
+
)
|
13 |
+
return response
|
14 |
+
|
15 |
+
def call_api(api_path, api_params):
|
16 |
+
session = requests.Session()
|
17 |
+
url = f"{API_ENDPOINT}/{api_path}"
|
18 |
+
response = session.post(
|
19 |
+
url, json=api_params, headers={"Content-Type": "application/json"}
|
20 |
+
)
|
21 |
+
return response.json()
|
22 |
+
|
23 |
+
def api_rag_qa_chain_demo(openai_model_name, query, year, company_name):
|
24 |
+
api_path = "qa/demo"
|
25 |
+
api_params = {
|
26 |
+
"openai_model_name": openai_model_name,
|
27 |
+
"query": query,
|
28 |
+
"year": year,
|
29 |
+
"company_name": company_name,
|
30 |
+
}
|
31 |
+
return call_api_stream(api_path, api_params)
|
32 |
+
|
33 |
+
def api_rag_summ_chain_demo(openai_model_name, query, year, company_name, tone):
|
34 |
+
api_path = "qa/waterfee"
|
35 |
+
api_params = {
|
36 |
+
"openai_model_name": openai_model_name,
|
37 |
+
"query": query,
|
38 |
+
"year": year,
|
39 |
+
"company_name": company_name,
|
40 |
+
"tone": tone,
|
41 |
+
}
|
42 |
+
return call_api_stream(api_path, api_params)
|
app.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import arrow
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import pandas as pd
|
6 |
+
from pathlib import Path
|
7 |
+
from time import sleep
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from api_calls import *
|
11 |
+
|
12 |
+
ROOT_DIR = Path(__file__).resolve().parents[0]
|
13 |
+
|
14 |
+
|
15 |
+
def export_to_txt(output):
|
16 |
+
today_dt_str = arrow.now(tz="Asia/Taipei").format("YYYYMMDDTHHmmss")
|
17 |
+
with open(f"esg_report_summary-{today_dt_str}.txt", "w") as f:
|
18 |
+
f.write(output)
|
19 |
+
return f"esg_report_summary-{today_dt_str}.txt"
|
20 |
+
|
21 |
+
def print_like_dislike(x: gr.LikeData):
|
22 |
+
print(x.index, x.value, x.liked)
|
23 |
+
|
24 |
+
def add_text(history, text):
|
25 |
+
history = history + [(text, None)]
|
26 |
+
return history, gr.Textbox(value="", interactive=False)
|
27 |
+
|
28 |
+
def esgsumm_exe(openai_model_name, year, company_name, tone):
|
29 |
+
query = "根據您提供的相關資訊和偏好語氣,以繁體中文生成一份符合GRI標準的報告草稿。報告將包括每個GRI披露項目的標題、相關公司行為的概要,以及公司的具體措施和效果。"
|
30 |
+
response = api_rag_summ_chain_demo(openai_model_name, query, year, company_name, tone)
|
31 |
+
full_anwser = ""
|
32 |
+
for chunk in response.iter_content(chunk_size=32):
|
33 |
+
if chunk:
|
34 |
+
try:
|
35 |
+
_c = chunk.decode('utf-8')
|
36 |
+
except UnicodeDecodeError:
|
37 |
+
_c = " "
|
38 |
+
full_anwser += _c
|
39 |
+
yield full_anwser
|
40 |
+
# for character in response:
|
41 |
+
# full_text += character
|
42 |
+
# yield full_text
|
43 |
+
|
44 |
+
def esgqabot(history, openai_model_name, year, company_name):
|
45 |
+
query = history[-1][0]
|
46 |
+
response = api_rag_qa_chain_demo(openai_model_name, query, year, company_name)
|
47 |
+
history[-1][1] = ""
|
48 |
+
for chunk in response.iter_content(chunk_size=32):
|
49 |
+
if chunk:
|
50 |
+
try:
|
51 |
+
_c = chunk.decode('utf-8')
|
52 |
+
except UnicodeDecodeError:
|
53 |
+
_c = " "
|
54 |
+
history[-1][1] += _c
|
55 |
+
yield history
|
56 |
+
# for character in response:
|
57 |
+
# history[-1][1] += character
|
58 |
+
# yield history
|
59 |
+
|
60 |
+
|
61 |
+
css = """
|
62 |
+
#center {text-align: center}
|
63 |
+
footer {visibility: hidden}
|
64 |
+
a {color: rgb(255, 206, 10) !important}
|
65 |
+
"""
|
66 |
+
with gr.Blocks(css=css, theme=gr.themes.Monochrome(neutral_hue="lime")) as demo:
|
67 |
+
|
68 |
+
gr.HTML("<h1>ESG RAG Playground</h1>", elem_id="center")
|
69 |
+
gr.Markdown("Made by `Abao`", elem_id="center")
|
70 |
+
gr.Markdown("---")
|
71 |
+
|
72 |
+
# esgsumm
|
73 |
+
with gr.Tab("ESG Report Summarization"):
|
74 |
+
gr.HTML("<h2>Report Summarization</h2><p>Summarize report with tone & schema.</p>", elem_id="center")
|
75 |
+
with gr.Row():
|
76 |
+
with gr.Group():
|
77 |
+
gr.Markdown("### Configuration", elem_id="center")
|
78 |
+
esgsumm_report_tone = gr.Dropdown(
|
79 |
+
label="Tone",
|
80 |
+
choices=["富有創意", "中庸", "精確"])
|
81 |
+
esgsumm_openai_model_name = gr.Dropdown(
|
82 |
+
label="OpenAI Model",
|
83 |
+
choices=["gpt-4-turbo-preview", "gpt-3.5-turbo"])
|
84 |
+
esgsumm_year = gr.Dropdown(
|
85 |
+
label="Year",
|
86 |
+
choices=["111", "110", "109"]
|
87 |
+
)
|
88 |
+
esgsumm_company_name = gr.Dropdown(
|
89 |
+
label="Company Name",
|
90 |
+
choices=["台泥", "聯電", "裕融", "大同", "台積電", "鴻海", "中鋼", "中華電信"]
|
91 |
+
)
|
92 |
+
esgsumm_report_gen_button = gr.Button("Generate Report")
|
93 |
+
|
94 |
+
with gr.Column():
|
95 |
+
gr.Markdown("## Generate ESG Summarization", elem_id="center")
|
96 |
+
with gr.Accordion("Revise Your Prompt", open=False):
|
97 |
+
esgsumm_checkbox_replace = gr.Checkbox(label="Replace with new prompt")
|
98 |
+
esgsumm_prompt_tmpl = gr.Textbox(
|
99 |
+
label="希望用於本次問答的prompt",
|
100 |
+
info="必須使用到的變數:{filtered_data}、{query}",
|
101 |
+
value=prompt_dict["qa"],
|
102 |
+
interactive=True,
|
103 |
+
)
|
104 |
+
esgsumm_report_output = gr.Textbox(
|
105 |
+
label="Report Output",
|
106 |
+
interactive=False,
|
107 |
+
scale=4,
|
108 |
+
)
|
109 |
+
esgsumm_download_btn = gr.Button("Export Summary")
|
110 |
+
esgsumm_download_file = gr.File(
|
111 |
+
label="Download Summary Text", file_types=[".txt"]
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
# esgqa
|
116 |
+
with gr.Tab("ESG QA"):
|
117 |
+
gr.HTML("<h2>ParallelQA (GPT-4 like)</h2><p>Test multiple LLMs at once.</p>", elem_id="center")
|
118 |
+
with gr.Row():
|
119 |
+
with gr.Group():
|
120 |
+
gr.Markdown("### Configuration", elem_id="center")
|
121 |
+
esgqa_openai_model_name = gr.Dropdown(
|
122 |
+
label="OpenAI Model",
|
123 |
+
choices=["gpt-4-turbo-preview", "gpt-3.5-turbo"])
|
124 |
+
esgqa_year = gr.Dropdown(
|
125 |
+
label="Year",
|
126 |
+
choices=["111", "110", "109"]
|
127 |
+
)
|
128 |
+
esgqa_company_name = gr.Dropdown(
|
129 |
+
label="Company Name",
|
130 |
+
choices=["台泥", "聯電", "裕融", "大同", "台積電", "鴻海", "中鋼", "中華電信"]
|
131 |
+
)
|
132 |
+
|
133 |
+
with gr.Column():
|
134 |
+
gr.Markdown("## Chat with ESGQABot", elem_id="center")
|
135 |
+
with gr.Accordion("Revise Your Prompt", open=False):
|
136 |
+
esgqa_checkbox_replace = gr.Checkbox(label="Replace with new prompt")
|
137 |
+
esgqa_prompt_tmpl = gr.Textbox(
|
138 |
+
label="希望用於本次問答的prompt",
|
139 |
+
info="必須使用到的變數:{filtered_data}、{query}",
|
140 |
+
value=prompt_dict["qa"],
|
141 |
+
interactive=True,
|
142 |
+
)
|
143 |
+
esgqa_chatbot = gr.Chatbot(
|
144 |
+
[(None, "我是 ESGQABot\n有什麼能為您服務的嗎?")],
|
145 |
+
elem_id="chatbot",
|
146 |
+
scale=1,
|
147 |
+
height=700,
|
148 |
+
bubble_full_width=False
|
149 |
+
)
|
150 |
+
with gr.Row():
|
151 |
+
esgqa_chatbot_input = gr.Textbox(
|
152 |
+
scale=4,
|
153 |
+
show_label=False,
|
154 |
+
placeholder="Enter text and press enter, or upload an image",
|
155 |
+
container=False,
|
156 |
+
)
|
157 |
+
esgqa_chat_btn = gr.Button("💬")
|
158 |
+
|
159 |
+
|
160 |
+
# esgsumm
|
161 |
+
esgsumm_report_gen_button.click(
|
162 |
+
esgsumm_exe, [esgsumm_openai_model_name, esgsumm_year, esgsumm_company_name, esgsumm_report_tone], esgsumm_report_output
|
163 |
+
)
|
164 |
+
esgsumm_download_btn.click(
|
165 |
+
fn=export_to_txt,
|
166 |
+
inputs=[esgsumm_report_output],
|
167 |
+
outputs=esgsumm_download_file,
|
168 |
+
)
|
169 |
+
|
170 |
+
# esgqa
|
171 |
+
esgqa_chatbot_input.submit(
|
172 |
+
add_text, [esgqa_chatbot, esgqa_chatbot_input], [esgqa_chatbot, esgqa_chatbot_input], queue=False
|
173 |
+
).then(
|
174 |
+
esgqabot, [esgqa_chatbot, esgqa_openai_model_name, esgqa_year, esgqa_company_name], esgqa_chatbot, api_name="esgqa_response"
|
175 |
+
).then(
|
176 |
+
lambda: gr.Textbox(interactive=True), None, [esgqa_chatbot_input], queue=False
|
177 |
+
)
|
178 |
+
esgqa_chat_btn.click(
|
179 |
+
add_text, [esgqa_chatbot, esgqa_chatbot_input], [esgqa_chatbot, esgqa_chatbot_input], queue=False
|
180 |
+
).then(
|
181 |
+
esgqabot, [esgqa_chatbot, esgqa_openai_model_name, esgqa_year, esgqa_company_name], esgqa_chatbot, api_name="esgqa_response"
|
182 |
+
).then(
|
183 |
+
lambda: gr.Textbox(interactive=True), None, [esgqa_chatbot_input], queue=False
|
184 |
+
)
|
185 |
+
esgqa_chatbot.like(print_like_dislike, None, None)
|
186 |
+
|
187 |
+
|
188 |
+
if __name__ == "__main__":
|
189 |
+
demo.queue().launch(max_threads=10)
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
arrow
|
2 |
+
pandas
|
3 |
+
requests
|