stevengrove committed on
Commit
4ca98ba
1 Parent(s): 5f71fb3

add prompt support

Browse files
Files changed (2) hide show
  1. app.py +206 -116
  2. prompts/interview.json +75 -0
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import re
 
2
  import argparse
3
 
4
  import openai
5
  import gradio as gr
 
6
 
7
 
8
  SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user.""" # noqa: E501
@@ -11,113 +13,176 @@ USER_FORMAT = """Interview Dialogues:
11
  {input_txt}
12
 
13
  Please select the rounds containing one of following tags: {pos_tags}.
14
- Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason".""" # noqa: E501
15
-
16
-
17
def preprocess(input_txt, max_length=4000, max_convs=4):
    """Split a raw interview transcript into prompt-sized chunks.

    Parameters
    ----------
    input_txt : str
        Raw transcript containing Chinese speaker tags of the form
        "说话人N hh:mm" followed by that speaker's words.
    max_length : int
        Upper bound (in characters) for each returned chunk.
    max_convs : int
        Maximum number of speaker rounds packed into one chunk.

    Returns
    -------
    list[str]
        Chunks formatted one round per line: "说话人N hh:mm: content".
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    # Capturing split keeps the speaker tags as their own list items.
    input_txt = speaker_pattern.split(input_txt)
    input_txt = [x.strip().replace('\n', ' ') for x in input_txt]

    # Pair each speaker tag with the text that follows it.
    conversations = []
    for idx, txt in enumerate(input_txt):
        if txt.startswith('说话人'):
            if idx < len(input_txt) - 1:
                if not input_txt[idx + 1].startswith('说话人'):
                    conv = [txt, input_txt[idx + 1]]
                else:
                    # Two consecutive tags: this round has no content.
                    conv = [txt, '']
                # Break an over-long round into pieces, repeating the tag
                # at the start of every piece.
                # NOTE(review): if the tag itself were >= max_length the
                # slice size would be non-positive and this loop would not
                # terminate — presumably tags are always short; verify.
                while len(''.join(conv)) > max_length:
                    pruned_len = max_length - len(''.join(conv[0]))
                    pruned_conv = [txt, conv[1][:pruned_len]]
                    conversations.append(pruned_conv)
                    conv = [txt, conv[-1][pruned_len:]]
                conversations.append(conv)

    # Greedily pack rounds into chunks bounded by max_length and max_convs.
    input_txt_list = ['']
    for conv in conversations:
        conv_length = len(''.join(conv))
        if len(input_txt_list[-1]) + conv_length >= max_length:
            input_txt_list.append('')
        elif len(speaker_pattern.findall(input_txt_list[-1])) >= max_convs:
            input_txt_list.append('')
        input_txt_list[-1] += ''.join(conv)

    # Re-insert line breaks and "tag: " separators for readability.
    processed_txt_list = []
    for input_txt in input_txt_list:
        input_txt = ''.join(input_txt)
        input_txt = speaker_pattern.sub(r'\n\1: ', input_txt)
        processed_txt_list.append(input_txt.strip())
    return processed_txt_list
52
-
53
-
54
def chatgpt(messages, temperature=0.0, max_retries=5):
    """Call the OpenAI chat API, retrying failures a bounded number of times.

    Parameters
    ----------
    messages : list[dict]
        Chat messages in the OpenAI ``{'role': ..., 'content': ...}`` format.
    temperature : float
        Sampling temperature forwarded to the API.
    max_retries : int
        Maximum number of attempts before giving up (new, defaulted, so
        existing callers are unaffected).

    Raises
    ------
    Exception
        The last API error, if every attempt fails.
    """
    # BUG FIX: the original recursed unconditionally on *any* exception,
    # which loops forever (and eventually overflows the stack) on
    # non-transient errors such as an invalid API key.
    last_err = None
    for _ in range(max_retries):
        try:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=temperature
            )
            return completion.choices[0].message.content
        except Exception as err:  # the API raises many distinct error types
            print(err)
            last_err = err
    raise last_err
65
-
66
-
67
def llm(pos_tags, neg_tags, input_txt):
    """Query the model about one transcript chunk.

    Builds the system/user message pair from the module-level prompt
    templates and returns the model's reply. ``neg_tags`` is accepted for
    interface compatibility but is not used by the current prompt.
    """
    user_msg = USER_FORMAT.format(input_txt=input_txt, pos_tags=pos_tags)
    chat_messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_msg},
    ]
    answer = chatgpt(chat_messages)
    # Log the exchange for debugging.
    print(f'USER:\n\n{user_msg}')
    print(f'RESPONSE:\n\n{answer}')
    return answer
78
-
79
-
80
def postprocess(input_txt, output_txt_list):
    """Keep only LLM output rounds whose speaker tag appears in the input.

    Joins the per-chunk model outputs, drops chunks with no speaker tag,
    filters out rounds attributed to speakers that never occur in the
    original transcript, and returns one round per line.
    """
    tag_re = re.compile(r'(说话人\d+ \d\d:\d\d)')

    # Discard model chunks that contain no speaker tag at all.
    tagged_chunks = [chunk for chunk in output_txt_list
                     if tag_re.search(chunk)]
    merged = ''.join(tagged_chunks)

    known_speakers = set(tag_re.findall(input_txt))
    pieces = tag_re.split(merged)

    kept = []
    for pos, piece in enumerate(pieces):
        if not piece.startswith('说话人'):
            continue
        # Hallucinated speakers (absent from the input) are dropped.
        if piece not in known_speakers:
            continue
        if pos < len(pieces) - 1:
            follower = pieces[pos + 1]
            if follower.startswith('说话人'):
                entry = piece
            else:
                entry = piece + follower
            kept.append(entry.strip())
    return '\n'.join(kept)
102
-
103
-
104
def filter(api_key, pos_tags, neg_tags, input_txt):
    """Run the full pipeline: split the transcript, query the LLM per
    chunk, then keep only rounds matching real speakers.

    Returns an error string when no API key was supplied.
    """
    if api_key is None or api_key == '':
        return 'OPENAI API Key is not set.'
    openai.api_key = api_key
    chunks = preprocess(input_txt)
    responses = [llm(pos_tags, neg_tags, chunk) for chunk in chunks]
    return postprocess(input_txt, responses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
 
117
  if __name__ == '__main__':
118
  parser = argparse.ArgumentParser()
 
 
 
 
119
  args = parser.parse_args()
120
 
 
 
 
 
 
 
 
 
 
 
 
121
  with gr.Blocks() as demo:
122
  with gr.Row():
123
  with gr.Column(scale=0.3):
@@ -128,21 +193,39 @@ if __name__ == '__main__':
128
  elem_id='api_key_textbox',
129
  placeholder='Enter your OPENAI API Key')
130
  with gr.Row():
131
- pos_txt = gr.Textbox(
132
- lines=2,
133
- label='Positive Tags',
134
- elem_id='pos_textbox',
135
- placeholder='Enter positive tags split by semicolon')
 
 
 
 
 
 
 
 
136
  with gr.Row():
137
- neg_txt = gr.Textbox(
138
- lines=2,
139
- visible=False,
140
- label='Negative Tags',
141
- elem_id='neg_textbox',
142
- placeholder='Enter negative tags split by semicolon')
 
 
 
 
 
 
 
 
 
143
  with gr.Row():
144
  input_txt = gr.Textbox(
145
  lines=4,
 
146
  label='Input',
147
  elem_id='input_textbox',
148
  placeholder='Enter text and press submit')
@@ -152,17 +235,24 @@ if __name__ == '__main__':
152
  clear = gr.Button('Clear')
153
  with gr.Column(scale=0.7):
154
  output_txt = gr.Textbox(
 
155
  label='Output',
156
  elem_id='output_textbox')
157
- output_txt = output_txt.style(height=690)
 
 
 
 
 
158
  submit.click(
159
- filter,
160
- [api_key, pos_txt, neg_txt, input_txt],
 
161
  [output_txt])
162
  clear.click(
163
  lambda: ['', '', ''],
164
  None,
165
- pos_txt, neg_txt, input_txt)
166
 
167
  demo.queue(concurrency_count=6)
168
  demo.launch()
 
1
  import re
2
+ import json
3
  import argparse
4
 
5
  import openai
6
  import gradio as gr
7
+ from functools import partial
8
 
9
 
10
  SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user.""" # noqa: E501
 
13
  {input_txt}
14
 
15
  Please select the rounds containing one of following tags: {pos_tags}.
16
+ Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason".""" # noqa: E501
17
+
18
+
19
class GPT4News():
    """Prompt-driven pipeline turning interview transcripts into
    filtered / summarized / rewritten text via the OpenAI chat API.

    ``prompt_formats`` is a list of prompt descriptors (see
    prompts/interview.json); each entry provides ``name``, ``system``,
    ``system_keys``, ``user``, ``user_keys``, ``post_filter``,
    ``split_length`` and ``split_round``.
    """

    def __init__(self, prompt_formats):
        # Index the prompt descriptors by their unique ``name``.
        self.name2prompt = {x['name']: x for x in prompt_formats}

    def preprocess(self, function_name, input_txt):
        """Normalize speaker tags and split the transcript into chunks.

        Returns a list of chunk strings, each at most ``split_length``
        characters and holding at most ``split_round`` speaker rounds,
        formatted one round per line: "Speaker N hh:mm: content".
        """
        max_length = self.name2prompt[function_name]['split_length']
        max_convs = self.name2prompt[function_name]['split_round']

        # Normalize Chinese speaker tags ("说话人N hh:mm") to English so
        # the patterns below and the LLM prompt see a single format.
        input_txt = re.sub(r'(说话人)(\d+ \d\d:\d\d)', r'Speaker \2', input_txt)
        speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
        # Capturing split keeps the speaker tags as their own list items.
        input_txt = speaker_pattern.split(input_txt)
        input_txt = [x.strip().replace('\n', ' ') for x in input_txt]

        # Pair each speaker tag with the text that follows it.
        conversations = []
        for idx, txt in enumerate(input_txt):
            if speaker_pattern.match(txt):
                if idx < len(input_txt) - 1:
                    if not speaker_pattern.match(input_txt[idx + 1]):
                        conv = [txt, input_txt[idx + 1]]
                    else:
                        # Two consecutive tags: the round has no content.
                        conv = [txt, '']
                    # Break an over-long round into pieces that all repeat
                    # the speaker tag.
                    while len(''.join(conv)) > max_length:
                        # ROBUSTNESS FIX: guarantee forward progress even
                        # if the tag alone reaches max_length (the original
                        # could loop forever on a non-positive slice size).
                        pruned_len = max(1, max_length - len(conv[0]))
                        pruned_conv = [txt, conv[1][:pruned_len]]
                        conversations.append(pruned_conv)
                        conv = [txt, conv[-1][pruned_len:]]
                    conversations.append(conv)

        # Greedily pack rounds into chunks bounded by max_length/max_convs.
        input_txt_list = ['']
        for conv in conversations:
            conv_length = len(''.join(conv))
            if len(input_txt_list[-1]) + conv_length >= max_length:
                input_txt_list.append('')
            elif len(speaker_pattern.findall(input_txt_list[-1])) >= max_convs:
                input_txt_list.append('')
            input_txt_list[-1] += ''.join(conv)

        # Re-insert line breaks and "tag: " separators for readability.
        processed_txt_list = []
        for chunk in input_txt_list:
            chunk = speaker_pattern.sub(r'\n\1: ', chunk)
            processed_txt_list.append(chunk.strip())
        return processed_txt_list

    def chatgpt(self, messages, temperature=0.0, max_retries=5):
        """Call the OpenAI chat API with bounded retries.

        BUG FIX: the original recursed on *any* exception with no limit,
        looping forever (and risking stack overflow) on non-transient
        errors such as a bad API key. ``max_retries`` is new but
        defaulted, so existing callers are unaffected. Raises the last
        API error if every attempt fails.
        """
        import time  # function-scope: only needed for retry backoff
        last_err = None
        for attempt in range(max_retries):
            try:
                completion = openai.ChatCompletion.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=temperature
                )
                return completion.choices[0].message.content
            except Exception as err:  # API raises many distinct error types
                print(err)
                last_err = err
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # simple exponential backoff
        raise last_err

    def llm(self, function_name, temperature, **kwargs):
        """Format the prompts for ``function_name`` and query the model.

        ``kwargs`` must contain every key listed in the descriptor's
        ``user_keys`` and ``system_keys``.
        """
        prompt = self.name2prompt[function_name]
        user_kwargs = {key: kwargs[key] for key in prompt['user_keys']}
        user = prompt['user'].format(**user_kwargs)
        system_kwargs = {key: kwargs[key] for key in prompt['system_keys']}
        system = prompt['system'].format(**system_kwargs)
        messages = [
            {'role': 'system',
             'content': system},
            {'role': 'user',
             'content': user}]
        response = self.chatgpt(messages, temperature=temperature)
        # Log the exchange for debugging.
        print(f'SYSTEM:\n\n{system}')
        print(f'USER:\n\n{user}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def translate(self, txt, output_lang):
        """Translate ``txt`` into ``output_lang`` via the model.

        English is the pipeline's working language, so it is a no-op.
        """
        if output_lang == 'English':
            return txt
        system = 'Translate the following text into {}:\n\n{}'.format(
            output_lang, txt)
        messages = [{'role': 'system', 'content': system}]
        response = self.chatgpt(messages)
        print(f'SYSTEM:\n\n{system}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def postprocess(self, function_name, input_txt, output_txt_list,
                    output_lang):
        """Merge per-chunk outputs and optionally filter/translate them.

        When the descriptor's ``post_filter`` is false the outputs are
        simply joined and translated. Otherwise only rounds attributed to
        speakers present in ``input_txt`` are kept, each translated
        individually.
        """
        if not self.name2prompt[function_name]['post_filter']:
            output_txt = '\n\n'.join(output_txt_list)
            output_txt = self.translate(output_txt, output_lang)
            return output_txt

        speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
        # Discard model chunks that contain no speaker tag at all.
        output_txt = []
        for txt in output_txt_list:
            if len(speaker_pattern.findall(txt)) > 0:
                output_txt.append(txt)
        output_txt = ''.join(output_txt)
        speakers = set(speaker_pattern.findall(input_txt))
        output_txt = speaker_pattern.split(output_txt)

        results = []
        for idx, txt in enumerate(output_txt):
            if speaker_pattern.match(txt):
                # Hallucinated speakers (absent from the input) are dropped.
                if txt not in speakers:
                    continue
                if idx < len(output_txt) - 1:
                    if not speaker_pattern.match(output_txt[idx + 1]):
                        res = txt + output_txt[idx + 1]
                    else:
                        res = txt
                    res = self.translate(res, output_lang)
                    results.append(res.strip())
        return '\n\n'.join(results)

    def __call__(self, api_key, function_name, temperature, output_lang,
                 input_txt, tags):
        """Run the full pipeline; entry point wired to the gradio UI.

        Returns an error string when the API key or function is missing.
        """
        if api_key is None or api_key == '':
            return 'OPENAI API Key is not set.'
        if function_name is None or function_name == '':
            return 'Function is not selected.'
        openai.api_key = api_key
        input_txt_list = self.preprocess(function_name, input_txt)
        # postprocess matches speakers against the *normalized* transcript
        # (tags were rewritten to "Speaker N hh:mm").
        input_txt = '\n'.join(input_txt_list)
        output_txt_list = []
        for txt in input_txt_list:
            llm_kwargs = dict(input_txt=txt,
                              tags=tags)
            output_txt = self.llm(function_name, temperature, **llm_kwargs)
            output_txt_list.append(output_txt)
        output_txt = self.postprocess(
            function_name, input_txt, output_txt_list, output_lang)
        return output_txt

    @property
    def function_names(self):
        """Names of all configured prompt functions (dict keys view)."""
        return self.name2prompt.keys()
156
+
157
+
158
def function_name_select_callback(components, name2prompt, function_name):
    """Gradio callback: toggle visibility of the per-function input widgets.

    Parameters
    ----------
    components : list[str]
        ``user_keys`` identifiers of the widgets wired to this callback,
        in the same order as the callback's gradio outputs.
        (Fixes the parameter-name typo ``componments``; the only caller
        binds it positionally via ``functools.partial``.)
    name2prompt : dict
        Mapping from function name to its prompt descriptor.
    function_name : str
        The function currently selected in the dropdown.

    Returns one ``gr.update`` per component, visible iff the selected
    function's prompt requires that component.
    """
    user_keys = name2prompt[function_name]['user_keys']
    return [gr.update(visible=comp in user_keys) for comp in components]
165
 
166
 
167
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default='prompts/interview.json',
                        help='path to the prompt file')
    # FIX: the default was the *string* '0.7'; use a real float so the
    # parsed value's type never depends on argparse's string-default
    # coercion behavior.
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='temperature for the llm model')
    args = parser.parse_args()

    # FIX: close the prompt file deterministically instead of leaking the
    # handle from json.load(open(...)).
    with open(args.prompt, 'r', encoding='utf-8') as fp:
        prompt_formats = json.load(fp)
    gpt4news = GPT4News(prompt_formats)

    # Target languages offered for the translated output.
    languages = ['Arabic', 'Bengali', 'Chinese (Simplified)',
                 'Chinese (Traditional)', 'Dutch', 'English', 'French',
                 'German', 'Hindi', 'Italian', 'Japanese', 'Korean',
                 'Portuguese', 'Punjabi', 'Russian', 'Spanish', 'Turkish',
                 'Urdu']
    # Pre-select the alphabetically first function and remember which
    # widgets it needs, so the UI starts in a consistent state.
    default_func = sorted(gpt4news.function_names)[0]
    default_user_keys = gpt4news.name2prompt[default_func]['user_keys']
+
186
  with gr.Blocks() as demo:
187
  with gr.Row():
188
  with gr.Column(scale=0.3):
 
193
  elem_id='api_key_textbox',
194
  placeholder='Enter your OPENAI API Key')
195
  with gr.Row():
196
+ function_name = gr.Dropdown(
197
+ sorted(gpt4news.function_names),
198
+ value=default_func,
199
+ elem_id='function_dropdown',
200
+ label='Function',
201
+ info='choose a function to run')
202
+ with gr.Row():
203
+ output_lang = gr.Dropdown(
204
+ languages,
205
+ value='English',
206
+ elem_id='output_lang_dropdown',
207
+ label='Output Language',
208
+ info='choose a language to output')
209
  with gr.Row():
210
+ temperature = gr.Slider(
211
+ minimum=0.0,
212
+ maximum=1.0,
213
+ value=args.temperature,
214
+ step=0.1,
215
+ interactive=True,
216
+ label='Temperature',
217
+ info='higher temperature means more creative')
218
+ with gr.Row():
219
+ tags = gr.Textbox(
220
+ lines=1,
221
+ visible='tags' in default_user_keys,
222
+ label='Tags',
223
+ elem_id='tags_textbox',
224
+ placeholder='Enter tags split by semicolon')
225
  with gr.Row():
226
  input_txt = gr.Textbox(
227
  lines=4,
228
+ visible='input_txt' in default_user_keys,
229
  label='Input',
230
  elem_id='input_textbox',
231
  placeholder='Enter text and press submit')
 
235
  clear = gr.Button('Clear')
236
  with gr.Column(scale=0.7):
237
  output_txt = gr.Textbox(
238
+ lines=8,
239
  label='Output',
240
  elem_id='output_textbox')
241
+ function_name.select(
242
+ partial(function_name_select_callback, ['input_txt', 'tags'],
243
+ gpt4news.name2prompt),
244
+ [function_name],
245
+ [input_txt, tags]
246
+ )
247
  submit.click(
248
+ gpt4news,
249
+ [api_key, function_name, temperature, output_lang,
250
+ input_txt, tags],
251
  [output_txt])
252
  clear.click(
253
  lambda: ['', '', ''],
254
  None,
255
+ tags, input_txt)
256
 
257
  demo.queue(concurrency_count=6)
258
  demo.launch()
prompts/interview.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "name": "searching",
4
+ "system": "You are a tool for filtering out paragraphs from the interview dialogues given by user.",
5
+ "system_keys": [],
6
+ "user": "Interview Dialogues:\n{input_txt}\n\nPlease select the rounds containing one of following tags: {tags}. Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like \"speaker_name speaking_time: tag, reason\".",
7
+ "user_keys": [
8
+ "input_txt",
9
+ "tags"
10
+ ],
11
+ "post_filter": true,
12
+ "split_length": 4000,
13
+ "split_round": 4
14
+ },
15
+ {
16
+ "name": "proofreading",
17
+ "system": "You are a proofreading tool used to improve the wording, grammar, and logical issues in a given interview record. Note that the output should maintain the original meaning, as well as keeping the speaker's name and interview time unchanged.",
18
+ "system_keys": [],
19
+ "user": "{input_txt}\n\n------\nPlease proofread the interview record and output the improved version. Note that the output should maintain the original meaning, as well as keeping the speaker's name and interview time unchanged.",
20
+ "user_keys": [
21
+ "input_txt"
22
+ ],
23
+ "post_filter": true,
24
+ "split_length": 4000,
25
+ "split_round": 4
26
+ },
27
+ {
28
+ "name": "summarization",
29
+ "system": "You are a text summarization tool used to summarize the meaning of each round of conversation in an interview record.",
30
+ "system_keys": [],
31
+ "user": "{input_txt}\n\n------\nPlease summarize the meaning of each round of conversation in an interview record. Note that the output should be concise and contains key information. The output should be like \"speaker_name speaking_time: summarization\"",
32
+ "user_keys": [
33
+ "input_txt"
34
+ ],
35
+ "post_filter": true,
36
+ "split_length": 4000,
37
+ "split_round": 4
38
+ },
39
+ {
40
+ "name": "summary to news",
41
+ "system": "You are a news writer who writes news articles based on the given summary of interview records.",
42
+ "system_keys": [],
43
+ "user": "{input_txt}\n\n------\nPlease write a news article based on the given summary of interview records.",
44
+ "user_keys": [
45
+ "input_txt"
46
+ ],
47
+ "post_filter": false,
48
+ "split_length": 10000000,
49
+ "split_round": 10000
50
+ },
51
+ {
52
+ "name": "summary to twitter",
53
+ "system": "You are a Twitter author who writes tweets based on the given summary of interview records.",
54
+ "system_keys": [],
55
+ "user": "{input_txt}\n\n------\nPlease writes a tweet based on the given summary of interview records. Note that the number of words in the output MUST be less than 140.",
56
+ "user_keys": [
57
+ "input_txt"
58
+ ],
59
+ "post_filter": false,
60
+ "split_length": 10000000,
61
+ "split_round": 10000
62
+ },
63
+ {
64
+ "name": "summary to weibo",
65
+ "system": "You are a Weibo author who writes eye-catching short articles based on the given summary of interview records.",
66
+ "system_keys": [],
67
+ "user": "{input_txt}\n\n------\nPlease write an eye-catching short article based on the given summary of interview records. Note that the number of words in the output MUST be less than 140.",
68
+ "user_keys": [
69
+ "input_txt"
70
+ ],
71
+ "post_filter": false,
72
+ "split_length": 10000000,
73
+ "split_round": 10000
74
+ }
75
+ ]