Ubuntu commited on
Commit
84e90e4
1 Parent(s): b9009d9

Initial Commit

Browse files
Files changed (3) hide show
  1. api_calls.py +42 -0
  2. app.py +189 -0
  3. requirements.txt +3 -0
api_calls.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ API_ENDPOINT = "http://54.254.230.28:8888"
4
+
5
+ # function to call api
6
+ def call_api_stream(api_path, api_params):
7
+ session = requests.Session()
8
+ url = f"{API_ENDPOINT}/{api_path}"
9
+ response = session.post(
10
+ url, json=api_params, headers={"Content-Type": "application/json"},
11
+ stream=True
12
+ )
13
+ return response
14
+
15
+ def call_api(api_path, api_params):
16
+ session = requests.Session()
17
+ url = f"{API_ENDPOINT}/{api_path}"
18
+ response = session.post(
19
+ url, json=api_params, headers={"Content-Type": "application/json"}
20
+ )
21
+ return response.json()
22
+
23
+ def api_rag_qa_chain_demo(openai_model_name, query, year, company_name):
24
+ api_path = "qa/demo"
25
+ api_params = {
26
+ "openai_model_name": openai_model_name,
27
+ "query": query,
28
+ "year": year,
29
+ "company_name": company_name,
30
+ }
31
+ return call_api_stream(api_path, api_params)
32
+
33
+ def api_rag_summ_chain_demo(openai_model_name, query, year, company_name, tone):
34
+ api_path = "qa/waterfee"
35
+ api_params = {
36
+ "openai_model_name": openai_model_name,
37
+ "query": query,
38
+ "year": year,
39
+ "company_name": company_name,
40
+ "tone": tone,
41
+ }
42
+ return call_api_stream(api_path, api_params)
app.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import arrow
2
+ import gradio as gr
3
+ import os
4
+ import re
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ from time import sleep
8
+ from tqdm import tqdm
9
+
10
+ from api_calls import *
11
+
12
+ ROOT_DIR = Path(__file__).resolve().parents[0]
13
+
14
+
15
+ def export_to_txt(output):
16
+ today_dt_str = arrow.now(tz="Asia/Taipei").format("YYYYMMDDTHHmmss")
17
+ with open(f"esg_report_summary-{today_dt_str}.txt", "w") as f:
18
+ f.write(output)
19
+ return f"esg_report_summary-{today_dt_str}.txt"
20
+
21
+ def print_like_dislike(x: gr.LikeData):
22
+ print(x.index, x.value, x.liked)
23
+
24
+ def add_text(history, text):
25
+ history = history + [(text, None)]
26
+ return history, gr.Textbox(value="", interactive=False)
27
+
28
+ def esgsumm_exe(openai_model_name, year, company_name, tone):
29
+ query = "根據您提供的相關資訊和偏好語氣,以繁體中文生成一份符合GRI標準的報告草稿。報告將包括每個GRI披露項目的標題、相關公司行為的概要,以及公司的具體措施和效果。"
30
+ response = api_rag_summ_chain_demo(openai_model_name, query, year, company_name, tone)
31
+ full_anwser = ""
32
+ for chunk in response.iter_content(chunk_size=32):
33
+ if chunk:
34
+ try:
35
+ _c = chunk.decode('utf-8')
36
+ except UnicodeDecodeError:
37
+ _c = " "
38
+ full_anwser += _c
39
+ yield full_anwser
40
+ # for character in response:
41
+ # full_text += character
42
+ # yield full_text
43
+
44
+ def esgqabot(history, openai_model_name, year, company_name):
45
+ query = history[-1][0]
46
+ response = api_rag_qa_chain_demo(openai_model_name, query, year, company_name)
47
+ history[-1][1] = ""
48
+ for chunk in response.iter_content(chunk_size=32):
49
+ if chunk:
50
+ try:
51
+ _c = chunk.decode('utf-8')
52
+ except UnicodeDecodeError:
53
+ _c = " "
54
+ history[-1][1] += _c
55
+ yield history
56
+ # for character in response:
57
+ # history[-1][1] += character
58
+ # yield history
59
+
60
+
61
+ css = """
62
+ #center {text-align: center}
63
+ footer {visibility: hidden}
64
+ a {color: rgb(255, 206, 10) !important}
65
+ """
66
+ with gr.Blocks(css=css, theme=gr.themes.Monochrome(neutral_hue="lime")) as demo:
67
+
68
+ gr.HTML("<h1>ESG RAG Playground</h1>", elem_id="center")
69
+ gr.Markdown("Made by `Abao`", elem_id="center")
70
+ gr.Markdown("---")
71
+
72
+ # esgsumm
73
+ with gr.Tab("ESG Report Summarization"):
74
+ gr.HTML("<h2>Report Summarization</h2><p>Summarize report with tone & schema.</p>", elem_id="center")
75
+ with gr.Row():
76
+ with gr.Group():
77
+ gr.Markdown("### Configuration", elem_id="center")
78
+ esgsumm_report_tone = gr.Dropdown(
79
+ label="Tone",
80
+ choices=["富有創意", "中庸", "精確"])
81
+ esgsumm_openai_model_name = gr.Dropdown(
82
+ label="OpenAI Model",
83
+ choices=["gpt-4-turbo-preview", "gpt-3.5-turbo"])
84
+ esgsumm_year = gr.Dropdown(
85
+ label="Year",
86
+ choices=["111", "110", "109"]
87
+ )
88
+ esgsumm_company_name = gr.Dropdown(
89
+ label="Company Name",
90
+ choices=["台泥", "聯電", "裕融", "大同", "台積電", "鴻海", "中鋼", "中華電信"]
91
+ )
92
+ esgsumm_report_gen_button = gr.Button("Generate Report")
93
+
94
+ with gr.Column():
95
+ gr.Markdown("## Generate ESG Summarization", elem_id="center")
96
+ with gr.Accordion("Revise Your Prompt", open=False):
97
+ esgsumm_checkbox_replace = gr.Checkbox(label="Replace with new prompt")
98
+ esgsumm_prompt_tmpl = gr.Textbox(
99
+ label="希望用於本次問答的prompt",
100
+ info="必須使用到的變數:{filtered_data}、{query}",
101
+ value=prompt_dict["qa"],
102
+ interactive=True,
103
+ )
104
+ esgsumm_report_output = gr.Textbox(
105
+ label="Report Output",
106
+ interactive=False,
107
+ scale=4,
108
+ )
109
+ esgsumm_download_btn = gr.Button("Export Summary")
110
+ esgsumm_download_file = gr.File(
111
+ label="Download Summary Text", file_types=[".txt"]
112
+ )
113
+
114
+
115
+ # esgqa
116
+ with gr.Tab("ESG QA"):
117
+ gr.HTML("<h2>ParallelQA (GPT-4 like)</h2><p>Test multiple LLMs at once.</p>", elem_id="center")
118
+ with gr.Row():
119
+ with gr.Group():
120
+ gr.Markdown("### Configuration", elem_id="center")
121
+ esgqa_openai_model_name = gr.Dropdown(
122
+ label="OpenAI Model",
123
+ choices=["gpt-4-turbo-preview", "gpt-3.5-turbo"])
124
+ esgqa_year = gr.Dropdown(
125
+ label="Year",
126
+ choices=["111", "110", "109"]
127
+ )
128
+ esgqa_company_name = gr.Dropdown(
129
+ label="Company Name",
130
+ choices=["台泥", "聯電", "裕融", "大同", "台積電", "鴻海", "中鋼", "中華電信"]
131
+ )
132
+
133
+ with gr.Column():
134
+ gr.Markdown("## Chat with ESGQABot", elem_id="center")
135
+ with gr.Accordion("Revise Your Prompt", open=False):
136
+ esgqa_checkbox_replace = gr.Checkbox(label="Replace with new prompt")
137
+ esgqa_prompt_tmpl = gr.Textbox(
138
+ label="希望用於本次問答的prompt",
139
+ info="必須使用到的變數:{filtered_data}、{query}",
140
+ value=prompt_dict["qa"],
141
+ interactive=True,
142
+ )
143
+ esgqa_chatbot = gr.Chatbot(
144
+ [(None, "我是 ESGQABot\n有什麼能為您服務的嗎?")],
145
+ elem_id="chatbot",
146
+ scale=1,
147
+ height=700,
148
+ bubble_full_width=False
149
+ )
150
+ with gr.Row():
151
+ esgqa_chatbot_input = gr.Textbox(
152
+ scale=4,
153
+ show_label=False,
154
+ placeholder="Enter text and press enter, or upload an image",
155
+ container=False,
156
+ )
157
+ esgqa_chat_btn = gr.Button("💬")
158
+
159
+
160
+ # esgsumm
161
+ esgsumm_report_gen_button.click(
162
+ esgsumm_exe, [esgsumm_openai_model_name, esgsumm_year, esgsumm_company_name, esgsumm_report_tone], esgsumm_report_output
163
+ )
164
+ esgsumm_download_btn.click(
165
+ fn=export_to_txt,
166
+ inputs=[esgsumm_report_output],
167
+ outputs=esgsumm_download_file,
168
+ )
169
+
170
+ # esgqa
171
+ esgqa_chatbot_input.submit(
172
+ add_text, [esgqa_chatbot, esgqa_chatbot_input], [esgqa_chatbot, esgqa_chatbot_input], queue=False
173
+ ).then(
174
+ esgqabot, [esgqa_chatbot, esgqa_openai_model_name, esgqa_year, esgqa_company_name], esgqa_chatbot, api_name="esgqa_response"
175
+ ).then(
176
+ lambda: gr.Textbox(interactive=True), None, [esgqa_chatbot_input], queue=False
177
+ )
178
+ esgqa_chat_btn.click(
179
+ add_text, [esgqa_chatbot, esgqa_chatbot_input], [esgqa_chatbot, esgqa_chatbot_input], queue=False
180
+ ).then(
181
+ esgqabot, [esgqa_chatbot, esgqa_openai_model_name, esgqa_year, esgqa_company_name], esgqa_chatbot, api_name="esgqa_response"
182
+ ).then(
183
+ lambda: gr.Textbox(interactive=True), None, [esgqa_chatbot_input], queue=False
184
+ )
185
+ esgqa_chatbot.like(print_like_dislike, None, None)
186
+
187
+
188
+ if __name__ == "__main__":
189
+ demo.queue().launch(max_threads=10)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ arrow
2
+ pandas
3
+ requests