Spaces:
Runtime error
Runtime error
stevengrove
committed on
Commit
•
4ca98ba
1
Parent(s):
5f71fb3
add prompt support
Browse files- app.py +206 -116
- prompts/interview.json +75 -0
app.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
import re
|
|
|
2 |
import argparse
|
3 |
|
4 |
import openai
|
5 |
import gradio as gr
|
|
|
6 |
|
7 |
|
8 |
SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user.""" # noqa: E501
|
@@ -11,113 +13,176 @@ USER_FORMAT = """Interview Dialogues:
|
|
11 |
{input_txt}
|
12 |
|
13 |
Please select the rounds containing one of following tags: {pos_tags}.
|
14 |
-
Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
if
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
|
117 |
if __name__ == '__main__':
|
118 |
parser = argparse.ArgumentParser()
|
|
|
|
|
|
|
|
|
119 |
args = parser.parse_args()
|
120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
with gr.Blocks() as demo:
|
122 |
with gr.Row():
|
123 |
with gr.Column(scale=0.3):
|
@@ -128,21 +193,39 @@ if __name__ == '__main__':
|
|
128 |
elem_id='api_key_textbox',
|
129 |
placeholder='Enter your OPENAI API Key')
|
130 |
with gr.Row():
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
elem_id='
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
with gr.Row():
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
with gr.Row():
|
144 |
input_txt = gr.Textbox(
|
145 |
lines=4,
|
|
|
146 |
label='Input',
|
147 |
elem_id='input_textbox',
|
148 |
placeholder='Enter text and press submit')
|
@@ -152,17 +235,24 @@ if __name__ == '__main__':
|
|
152 |
clear = gr.Button('Clear')
|
153 |
with gr.Column(scale=0.7):
|
154 |
output_txt = gr.Textbox(
|
|
|
155 |
label='Output',
|
156 |
elem_id='output_textbox')
|
157 |
-
|
|
|
|
|
|
|
|
|
|
|
158 |
submit.click(
|
159 |
-
|
160 |
-
[api_key,
|
|
|
161 |
[output_txt])
|
162 |
clear.click(
|
163 |
lambda: ['', '', ''],
|
164 |
None,
|
165 |
-
|
166 |
|
167 |
demo.queue(concurrency_count=6)
|
168 |
demo.launch()
|
|
|
1 |
import re
|
2 |
+
import json
|
3 |
import argparse
|
4 |
|
5 |
import openai
|
6 |
import gradio as gr
|
7 |
+
from functools import partial
|
8 |
|
9 |
|
10 |
SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user.""" # noqa: E501
|
|
|
13 |
{input_txt}
|
14 |
|
15 |
Please select the rounds containing one of following tags: {pos_tags}.
|
16 |
+
Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason".""" # noqa: E501
|
17 |
+
|
18 |
+
|
19 |
+
class GPT4News():
    """Interview-transcript toolchain driven by JSON prompt templates.

    Splits a transcript into model-sized chunks, sends each chunk to the
    OpenAI chat API with the prompt selected by ``function_name``, and
    merges/filters/translates the responses.

    Args:
        prompt_formats: list of prompt dicts (one per function); each dict
            provides 'name', 'system', 'system_keys', 'user', 'user_keys',
            'post_filter', 'split_length' and 'split_round'.
    """

    # Upper bound on API retries; the original implementation retried via
    # unbounded recursion and could overflow the stack on persistent errors.
    MAX_RETRIES = 5

    def __init__(self, prompt_formats):
        # Index prompts by their 'name' for O(1) lookup.
        self.name2prompt = {x['name']: x for x in prompt_formats}

    def preprocess(self, function_name, input_txt):
        """Normalize speaker tags and split *input_txt* into bounded chunks.

        Returns:
            list[str]: chunks of at most 'split_length' characters and at
            most 'split_round' speaker rounds each, with every round
            rendered as "Speaker N hh:mm: utterance" on its own line.
        """
        max_length = self.name2prompt[function_name]['split_length']
        max_convs = self.name2prompt[function_name]['split_round']

        # Translate Chinese speaker tags ("说话人N hh:mm") into English ones.
        input_txt = re.sub(r'(说话人)(\d+ \d\d:\d\d)', r'Speaker \2', input_txt)
        speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
        parts = [x.strip().replace('\n', ' ')
                 for x in speaker_pattern.split(input_txt)]

        # Pair every speaker tag with the utterance that follows it;
        # utterances longer than max_length are pruned into several pieces.
        conversations = []
        for idx, txt in enumerate(parts):
            if not speaker_pattern.match(txt):
                continue
            if idx < len(parts) - 1 and \
                    not speaker_pattern.match(parts[idx + 1]):
                conv = [txt, parts[idx + 1]]
            else:
                # Tag without an utterance; also covers a tag that is the
                # final element (the original left `conv` unbound here and
                # could raise NameError or reuse a stale value).
                conv = [txt, '']
            while len(''.join(conv)) > max_length:
                pruned_len = max_length - len(conv[0])
                if pruned_len <= 0:
                    # The tag alone exceeds max_length; pruning cannot make
                    # progress (the original looped forever here).
                    break
                conversations.append([txt, conv[1][:pruned_len]])
                conv = [txt, conv[-1][pruned_len:]]
            conversations.append(conv)

        # Greedily pack rounds into chunks, opening a new chunk whenever
        # the length or round-count budget would be exceeded.
        input_txt_list = ['']
        for conv in conversations:
            conv_length = len(''.join(conv))
            if len(input_txt_list[-1]) + conv_length >= max_length:
                input_txt_list.append('')
            elif len(speaker_pattern.findall(input_txt_list[-1])) >= max_convs:
                input_txt_list.append('')
            input_txt_list[-1] += ''.join(conv)

        # Put every speaker tag on its own line as "Speaker N hh:mm: ".
        processed_txt_list = []
        for chunk in input_txt_list:
            chunk = speaker_pattern.sub(r'\n\1: ', chunk)
            processed_txt_list.append(chunk.strip())
        return processed_txt_list

    def chatgpt(self, messages, temperature=0.0):
        """Send *messages* to the chat API, retrying transient failures.

        Raises:
            Exception: the last API error after MAX_RETRIES failed attempts
            (the original recursed without bound instead of giving up).
        """
        last_err = None
        for _ in range(self.MAX_RETRIES):
            try:
                completion = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=temperature
                )
                return completion.choices[0].message.content
            except Exception as err:
                # Log and retry; the openai client raises many error types.
                print(err)
                last_err = err
        raise last_err

    def llm(self, function_name, temperature, **kwargs):
        """Format the prompt for *function_name* and query the model."""
        prompt = self.name2prompt[function_name]
        # Each template declares which kwargs it consumes.
        user_kwargs = {key: kwargs[key] for key in prompt['user_keys']}
        user = prompt['user'].format(**user_kwargs)
        system_kwargs = {key: kwargs[key] for key in prompt['system_keys']}
        system = prompt['system'].format(**system_kwargs)
        messages = [
            {'role': 'system',
             'content': system},
            {'role': 'user',
             'content': user}]
        response = self.chatgpt(messages, temperature=temperature)
        print(f'SYSTEM:\n\n{system}')
        print(f'USER:\n\n{user}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def translate(self, txt, output_lang):
        """Translate *txt* into *output_lang*; English is returned as-is."""
        if output_lang == 'English':
            return txt
        system = 'Translate the following text into {}:\n\n{}'.format(
            output_lang, txt)
        messages = [{'role': 'system', 'content': system}]
        response = self.chatgpt(messages)
        print(f'SYSTEM:\n\n{system}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def postprocess(self, function_name, input_txt, output_txt_list,
                    output_lang):
        """Merge per-chunk model outputs and optionally filter them.

        When the prompt's 'post_filter' flag is false the chunks are simply
        joined and translated.  Otherwise only rounds whose speaker tag also
        occurs in *input_txt* are kept, dropping speakers the model made up.
        """
        if not self.name2prompt[function_name]['post_filter']:
            output_txt = '\n\n'.join(output_txt_list)
            output_txt = self.translate(output_txt, output_lang)
            return output_txt

        speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
        output_txt = []
        for txt in output_txt_list:
            # Keep only chunks that contain at least one speaker tag.
            if len(speaker_pattern.findall(txt)) > 0:
                output_txt.append(txt)
        output_txt = ''.join(output_txt)
        speakers = set(speaker_pattern.findall(input_txt))
        output_txt = speaker_pattern.split(output_txt)

        results = []
        for idx, txt in enumerate(output_txt):
            if not speaker_pattern.match(txt):
                continue
            if txt not in speakers:
                continue  # speaker never appeared in the input
            if idx < len(output_txt) - 1 and \
                    not speaker_pattern.match(output_txt[idx + 1]):
                res = txt + output_txt[idx + 1]
            else:
                # Tag with no trailing text; also covers the final element
                # (the original left `res` unbound/stale in that case).
                res = txt
            res = self.translate(res, output_lang)
            results.append(res.strip())
        return '\n\n'.join(results)

    def __call__(self, api_key, function_name, temperature, output_lang,
                 input_txt, tags):
        """Run the selected function end to end and return the output text."""
        if api_key is None or api_key == '':
            return 'OPENAI API Key is not set.'
        if function_name is None or function_name == '':
            return 'Function is not selected.'
        openai.api_key = api_key
        input_txt_list = self.preprocess(function_name, input_txt)
        input_txt = '\n'.join(input_txt_list)
        output_txt_list = []
        for txt in input_txt_list:
            llm_kwargs = dict(input_txt=txt,
                              tags=tags)
            output_txt = self.llm(function_name, temperature, **llm_kwargs)
            output_txt_list.append(output_txt)
        output_txt = self.postprocess(
            function_name, input_txt, output_txt_list, output_lang)
        return output_txt

    @property
    def function_names(self):
        """Names of all configured functions (a dict key view)."""
        return self.name2prompt.keys()
|
156 |
+
|
157 |
+
|
158 |
+
def function_name_select_callback(componments, name2prompt, function_name):
    """Return gradio visibility updates for *componments*: a component is
    shown iff its name is among the selected prompt's 'user_keys'."""
    visible_keys = name2prompt[function_name]['user_keys']
    return [gr.update(visible=name in visible_keys) for name in componments]
|
165 |
|
166 |
|
167 |
if __name__ == '__main__':
|
168 |
parser = argparse.ArgumentParser()
|
169 |
+
parser.add_argument('--prompt', type=str, default='prompts/interview.json',
|
170 |
+
help='path to the prompt file')
|
171 |
+
parser.add_argument('--temperature', type=float, default='0.7',
|
172 |
+
help='temperature for the llm model')
|
173 |
args = parser.parse_args()
|
174 |
|
175 |
+
prompt_formats = json.load(open(args.prompt, 'r'))
|
176 |
+
gpt4news = GPT4News(prompt_formats)
|
177 |
+
|
178 |
+
languages = ['Arabic', 'Bengali', 'Chinese (Simplified)',
|
179 |
+
'Chinese (Traditional)', 'Dutch', 'English', 'French',
|
180 |
+
'German', 'Hindi', 'Italian', 'Japanese', 'Korean',
|
181 |
+
'Portuguese', 'Punjabi', 'Russian', 'Spanish', 'Turkish',
|
182 |
+
'Urdu']
|
183 |
+
default_func = sorted(gpt4news.function_names)[0]
|
184 |
+
default_user_keys = gpt4news.name2prompt[default_func]['user_keys']
|
185 |
+
|
186 |
with gr.Blocks() as demo:
|
187 |
with gr.Row():
|
188 |
with gr.Column(scale=0.3):
|
|
|
193 |
elem_id='api_key_textbox',
|
194 |
placeholder='Enter your OPENAI API Key')
|
195 |
with gr.Row():
|
196 |
+
function_name = gr.Dropdown(
|
197 |
+
sorted(gpt4news.function_names),
|
198 |
+
value=default_func,
|
199 |
+
elem_id='function_dropdown',
|
200 |
+
label='Function',
|
201 |
+
info='choose a function to run')
|
202 |
+
with gr.Row():
|
203 |
+
output_lang = gr.Dropdown(
|
204 |
+
languages,
|
205 |
+
value='English',
|
206 |
+
elem_id='output_lang_dropdown',
|
207 |
+
label='Output Language',
|
208 |
+
info='choose a language to output')
|
209 |
with gr.Row():
|
210 |
+
temperature = gr.Slider(
|
211 |
+
minimum=0.0,
|
212 |
+
maximum=1.0,
|
213 |
+
value=args.temperature,
|
214 |
+
step=0.1,
|
215 |
+
interactive=True,
|
216 |
+
label='Temperature',
|
217 |
+
info='higher temperature means more creative')
|
218 |
+
with gr.Row():
|
219 |
+
tags = gr.Textbox(
|
220 |
+
lines=1,
|
221 |
+
visible='tags' in default_user_keys,
|
222 |
+
label='Tags',
|
223 |
+
elem_id='tags_textbox',
|
224 |
+
placeholder='Enter tags split by semicolon')
|
225 |
with gr.Row():
|
226 |
input_txt = gr.Textbox(
|
227 |
lines=4,
|
228 |
+
visible='input_txt' in default_user_keys,
|
229 |
label='Input',
|
230 |
elem_id='input_textbox',
|
231 |
placeholder='Enter text and press submit')
|
|
|
235 |
clear = gr.Button('Clear')
|
236 |
with gr.Column(scale=0.7):
|
237 |
output_txt = gr.Textbox(
|
238 |
+
lines=8,
|
239 |
label='Output',
|
240 |
elem_id='output_textbox')
|
241 |
+
function_name.select(
|
242 |
+
partial(function_name_select_callback, ['input_txt', 'tags'],
|
243 |
+
gpt4news.name2prompt),
|
244 |
+
[function_name],
|
245 |
+
[input_txt, tags]
|
246 |
+
)
|
247 |
submit.click(
|
248 |
+
gpt4news,
|
249 |
+
[api_key, function_name, temperature, output_lang,
|
250 |
+
input_txt, tags],
|
251 |
[output_txt])
|
252 |
clear.click(
|
253 |
lambda: ['', '', ''],
|
254 |
None,
|
255 |
+
tags, input_txt)
|
256 |
|
257 |
demo.queue(concurrency_count=6)
|
258 |
demo.launch()
|
prompts/interview.json
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"name": "searching",
|
4 |
+
"system": "You are a tool for filtering out paragraphs from the interview dialogues given by user.",
|
5 |
+
"system_keys": [],
|
6 |
+
"user": "Interview Dialogues:\n{input_txt}\n\nPlease select the rounds containing one of the following tags: {tags}. Note that you should ONLY output a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like \"speaker_name speaking_time: tag, reason\".",
|
7 |
+
"user_keys": [
|
8 |
+
"input_txt",
|
9 |
+
"tags"
|
10 |
+
],
|
11 |
+
"post_filter": true,
|
12 |
+
"split_length": 4000,
|
13 |
+
"split_round": 4
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"name": "proofreading",
|
17 |
+
"system": "You are a proofreading tool used to improve the wording, grammar, and logical issues in a given interview record. Note that the output should maintain the original meaning, as well as keeping the speaker's name and interview time unchanged.",
|
18 |
+
"system_keys": [],
|
19 |
+
"user": "{input_txt}\n\n------\nPlease proofread the interview record and output the improved version. Note that the output should maintain the original meaning, as well as keeping the speaker's name and interview time unchanged.",
|
20 |
+
"user_keys": [
|
21 |
+
"input_txt"
|
22 |
+
],
|
23 |
+
"post_filter": true,
|
24 |
+
"split_length": 4000,
|
25 |
+
"split_round": 4
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"name": "summarization",
|
29 |
+
"system": "You are a text summarization tool used to summarize the meaning of each round of conversation in an interview record.",
|
30 |
+
"system_keys": [],
|
31 |
+
"user": "{input_txt}\n\n------\nPlease summarize the meaning of each round of conversation in an interview record. Note that the output should be concise and contains key information. The output should be like \"speaker_name speaking_time: summarization\"",
|
32 |
+
"user_keys": [
|
33 |
+
"input_txt"
|
34 |
+
],
|
35 |
+
"post_filter": true,
|
36 |
+
"split_length": 4000,
|
37 |
+
"split_round": 4
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"name": "summary to news",
|
41 |
+
"system": "You are a news writer who writes news articles based on the given summary of interview records.",
|
42 |
+
"system_keys": [],
|
43 |
+
"user": "{input_txt}\n\n------\nPlease write a news article based on the given summary of interview records.",
|
44 |
+
"user_keys": [
|
45 |
+
"input_txt"
|
46 |
+
],
|
47 |
+
"post_filter": false,
|
48 |
+
"split_length": 10000000,
|
49 |
+
"split_round": 10000
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"name": "summary to twitter",
|
53 |
+
"system": "You are a Twitter author who writes tweets based on the given summary of interview records.",
|
54 |
+
"system_keys": [],
|
55 |
+
"user": "{input_txt}\n\n------\nPlease write a tweet based on the given summary of interview records. Note that the number of words in the output MUST be less than 140.",
|
56 |
+
"user_keys": [
|
57 |
+
"input_txt"
|
58 |
+
],
|
59 |
+
"post_filter": false,
|
60 |
+
"split_length": 10000000,
|
61 |
+
"split_round": 10000
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"name": "summary to weibo",
|
65 |
+
"system": "You are a Weibo author who writes eye-catching short articles based on the given summary of interview records.",
|
66 |
+
"system_keys": [],
|
67 |
+
"user": "{input_txt}\n\n------\nPlease write an eye-catching short article based on the given summary of interview records. Note that the number of words in the output MUST be less than 140.",
|
68 |
+
"user_keys": [
|
69 |
+
"input_txt"
|
70 |
+
],
|
71 |
+
"post_filter": false,
|
72 |
+
"split_length": 10000000,
|
73 |
+
"split_round": 10000
|
74 |
+
}
|
75 |
+
]
|