Spaces:
Runtime error
Runtime error
stevengrove
committed on
Commit
•
65cfc9d
1
Parent(s):
9d77f5e
initial commit
Browse files
app.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import argparse
|
4 |
+
|
5 |
+
import openai
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
|
9 |
+
# OpenAI API key, read from the ``OPENAI_KEY`` environment variable
# (``None`` when the variable is not set).
OPENAI_KEY = os.getenv('OPENAI_KEY')


# System message placed first in the message list built by ``llm``.
SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by user."""  # noqa: E501

# Template of the user message; ``{input_txt}`` and ``{pos_tags}`` are
# filled in with ``str.format`` by ``llm``.
USER_FORMAT = """Interview Dialogues:
{input_txt}

Please select the rounds containing one of following tags: {pos_tags}.
Note that you should ONLY outputs a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason"."""  # noqa: E501
|
19 |
+
|
20 |
+
|
21 |
+
def preprocess(input_txt, max_length=4000, max_convs=4):
    """Split an interview transcript into prompt-sized chunks.

    The transcript is cut at every speaker tag of the form
    ``说话人<N> MM:SS``. Each tag is paired with the text that follows it
    (surrounding whitespace stripped, inner newlines replaced by spaces)
    and consecutive turns are packed into chunks.

    Args:
        input_txt: Raw transcript text.
        max_length: Character budget. A turn whose tag plus text exceeds
            it is cut into several pieces that each repeat the tag, and a
            new chunk is started when adding a turn would reach it.
        max_convs: A new chunk is started once the current one already
            holds this many turns.

    Returns:
        A list of non-empty chunk strings in which every turn is rendered
        on its own line as ``<tag>: <text>``. The list is empty when the
        transcript contains no speaker tag.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    # With one capturing group, ``split`` returns
    # [preamble, tag, text, tag, text, ...]: tags sit at the odd indices
    # and each one is followed by its text. Pairing by position avoids
    # mistaking text that merely begins with "说话人" for a tag.
    parts = [x.strip().replace('\n', ' ')
             for x in speaker_pattern.split(input_txt)]

    conversations = []
    for tag, content in zip(parts[1::2], parts[2::2]):
        # Take at least one character per piece so the loop terminates
        # even when ``max_length`` is not larger than the tag itself.
        room = max(1, max_length - len(tag))
        while len(content) > room:
            conversations.append(tag + content[:room])
            content = content[room:]
        conversations.append(tag + content)

    chunks = []
    current, turns = '', 0
    for conv in conversations:
        full = len(current) + len(conv) >= max_length
        # Only close a chunk that already holds something; otherwise an
        # oversized first turn would leave an empty chunk behind.
        if current and (full or turns >= max_convs):
            chunks.append(current)
            current, turns = '', 0
        current += conv
        turns += 1
    if current:
        chunks.append(current)

    # Put every turn on its own line as "<tag>: <text>".
    return [speaker_pattern.sub(r'\n\1: ', chunk).strip()
            for chunk in chunks]
|
56 |
+
|
57 |
+
|
58 |
+
def chatgpt(messages, temperature=0.0, max_retries=3):
    """Send a chat request to ``gpt-3.5-turbo`` and return the reply text.

    Args:
        messages: Chat messages, passed through to
            ``openai.ChatCompletion.create``.
        temperature: Sampling temperature forwarded to the API.
        max_retries: Number of additional attempts made after a failed
            call before giving up.

    Returns:
        The content of the first choice's message.

    Raises:
        Exception: The error of the last attempt, once every attempt has
            failed.
    """
    import time

    request = dict(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=temperature)
    # The key comes from the ``OPENAI_KEY`` environment variable, which is
    # not the variable the openai package reads on its own, so pass it
    # explicitly when it is set.
    if OPENAI_KEY:
        request['api_key'] = OPENAI_KEY

    retries = max(0, max_retries)
    for attempt in range(retries + 1):
        try:
            completion = openai.ChatCompletion.create(**request)
            return completion.choices[0].message.content
        except Exception as err:
            print(err)
            # Bounded retry: a permanent failure (bad key, bad request)
            # must surface instead of retrying forever.
            if attempt >= retries:
                raise
            time.sleep(min(2 ** attempt, 8))
|
69 |
+
|
70 |
+
|
71 |
+
def llm(pos_tags, neg_tags, input_txt):
    """Ask the chat model which rounds of ``input_txt`` match ``pos_tags``.

    The user prompt is built from ``USER_FORMAT`` and sent together with
    ``SYSTEM_PROMPT``. Both the prompt and the reply are printed.
    ``neg_tags`` is accepted but not used.

    Returns:
        The reply returned by ``chatgpt``.
    """
    prompt = USER_FORMAT.format(input_txt=input_txt, pos_tags=pos_tags)
    system_msg = {'role': 'system', 'content': SYSTEM_PROMPT}
    user_msg = {'role': 'user', 'content': prompt}
    reply = chatgpt([system_msg, user_msg])
    print(f'USER:\n\n{prompt}')
    print(f'RESPONSE:\n\n{reply}')
    return reply
|
82 |
+
|
83 |
+
|
84 |
+
def postprocess(input_txt, output_txt_list):
    """Merge the model responses into one list of selected rounds.

    Args:
        input_txt: The original transcript. Only speaker tags
            (``说话人<N> MM:SS``) that occur in it are accepted.
        output_txt_list: Model responses, one per chunk. Responses that
            are empty or contain no speaker tag are ignored.

    Returns:
        A newline-joined string with one ``<tag><text>`` entry for every
        accepted speaker tag found in the responses, in response order.
        Empty when nothing was selected.
    """
    speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
    speakers = set(speaker_pattern.findall(input_txt))
    merged = ''.join(
        txt for txt in output_txt_list
        if txt and speaker_pattern.search(txt))

    # ``split`` with one capturing group yields
    # [preamble, tag, text, tag, text, ...]; pairing by position keeps
    # the text of a round even when it happens to begin with "说话人".
    parts = speaker_pattern.split(merged)
    results = []
    for tag, text in zip(parts[1::2], parts[2::2]):
        # Drop tags that do not occur in the transcript.
        if tag in speakers:
            results.append((tag + text).strip())
    return '\n'.join(results)
|
106 |
+
|
107 |
+
|
108 |
+
def filter(pos_tags, neg_tags, input_txt):
    """Run the whole pipeline on a transcript.

    The transcript is split by ``preprocess``, every chunk is sent
    through ``llm``, and the responses are merged by ``postprocess``.

    Returns:
        The string produced by ``postprocess``.
    """
    responses = [
        llm(pos_tags, neg_tags, chunk) for chunk in preprocess(input_txt)
    ]
    return postprocess(input_txt, responses)
|
116 |
+
|
117 |
+
|
118 |
+
if __name__ == '__main__':
    # No options are defined; parsing still provides ``--help`` and
    # rejects unknown arguments.
    parser = argparse.ArgumentParser()
    args = parser.parse_args()

    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: tag and transcript inputs plus the buttons.
            # Integer scales keep the 3:7 width ratio.
            with gr.Column(scale=3):
                with gr.Row():
                    pos_txt = gr.Textbox(
                        lines=2,
                        label='Positive Tags',
                        elem_id='pos_textbox',
                        placeholder='Enter positive tags split by semicolon')
                with gr.Row():
                    # Hidden: negative tags are passed along to ``filter``.
                    neg_txt = gr.Textbox(
                        lines=2,
                        visible=False,
                        label='Negative Tags',
                        elem_id='neg_textbox',
                        placeholder='Enter negative tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=5,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            # Right column: the filtered result.
            with gr.Column(scale=7):
                output_txt = gr.Textbox(
                    label='Output',
                    elem_id='output_textbox')
                output_txt = output_txt.style(height=690)
        submit.click(
            filter,
            [pos_txt, neg_txt, input_txt],
            [output_txt])
        # The three cleared components must be given as ONE outputs list;
        # as separate positional arguments they would be taken for other
        # parameters of ``click``.
        clear.click(
            lambda: ['', '', ''],
            None,
            [pos_txt, neg_txt, input_txt])

    demo.queue(concurrency_count=6)
    demo.launch()
|