stevengrove committed
Commit
65cfc9d
1 Parent(s): 9d77f5e

initial commit

Files changed (1)
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
+ import os
+ import re
+ import argparse
+
+ import openai
+ import gradio as gr
+
+
+ # Configure the OpenAI client with the API key from the environment.
+ openai.api_key = os.getenv('OPENAI_KEY')
+
+
+ SYSTEM_PROMPT = """You are a tool for filtering out paragraphs from the interview dialogues given by the user.""" # noqa: E501
+
+ USER_FORMAT = """Interview Dialogues:
+ {input_txt}
+
+ Please select the rounds containing one of the following tags: {pos_tags}.
+ Note that you should ONLY output a list of the speaker name, speaking time, tag and reason for each selected round. Do NOT output the content. Each output item should be like "speaker_name speaking_time: tag, reason".""" # noqa: E501
+
+
+ def preprocess(input_txt, max_length=4000, max_convs=4):
+     # Speaker tags look like '说话人1 00:12' ('说话人' means 'Speaker').
+     speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
+     input_txt = speaker_pattern.split(input_txt)
+     input_txt = [x.strip().replace('\n', ' ') for x in input_txt]
+
+     # Pair each speaker tag with the utterance that follows it, splitting
+     # utterances longer than max_length into smaller pieces.
+     conversations = []
+     for idx, txt in enumerate(input_txt):
+         if txt.startswith('说话人'):
+             if idx < len(input_txt) - 1:
+                 if not input_txt[idx + 1].startswith('说话人'):
+                     conv = [txt, input_txt[idx + 1]]
+                 else:
+                     conv = [txt, '']
+                 while len(''.join(conv)) > max_length:
+                     pruned_len = max_length - len(conv[0])
+                     pruned_conv = [txt, conv[1][:pruned_len]]
+                     conversations.append(pruned_conv)
+                     conv = [txt, conv[-1][pruned_len:]]
+                 conversations.append(conv)
+
+     # Pack rounds into chunks of at most max_length characters and at most
+     # max_convs rounds, so each chunk stays small enough for one request.
+     input_txt_list = ['']
+     for conv in conversations:
+         conv_length = len(''.join(conv))
+         if len(input_txt_list[-1]) + conv_length >= max_length:
+             input_txt_list.append('')
+         elif len(speaker_pattern.findall(input_txt_list[-1])) >= max_convs:
+             input_txt_list.append('')
+         input_txt_list[-1] += ''.join(conv)
+
+     # Re-insert line breaks so each round reads '说话人N hh:mm: <utterance>'.
+     processed_txt_list = []
+     for input_txt in input_txt_list:
+         input_txt = ''.join(input_txt)
+         input_txt = speaker_pattern.sub(r'\n\1: ', input_txt)
+         processed_txt_list.append(input_txt.strip())
+     return processed_txt_list
+
+
+ def chatgpt(messages, temperature=0.0):
+     try:
+         completion = openai.ChatCompletion.create(
+             model="gpt-3.5-turbo",
+             messages=messages,
+             temperature=temperature
+         )
+         return completion.choices[0].message.content
+     except Exception as err:
+         # Retry indefinitely on API errors such as rate limits.
+         print(err)
+         return chatgpt(messages, temperature)
+
+
+ def llm(pos_tags, neg_tags, input_txt):
+     # Only the positive tags are used in the prompt; neg_tags is unused.
+     user = USER_FORMAT.format(input_txt=input_txt, pos_tags=pos_tags)
+     messages = [
+         {'role': 'system',
+          'content': SYSTEM_PROMPT},
+         {'role': 'user',
+          'content': user}]
+     response = chatgpt(messages)
+     print(f'USER:\n\n{user}')
+     print(f'RESPONSE:\n\n{response}')
+     return response
+
+
+ def postprocess(input_txt, output_txt_list):
+     speaker_pattern = re.compile(r'(说话人\d+ \d\d:\d\d)')
+     # Keep only model outputs that actually contain a speaker tag.
+     output_txt = []
+     for txt in output_txt_list:
+         if len(speaker_pattern.findall(txt)) > 0:
+             output_txt.append(txt)
+     output_txt = ''.join(output_txt)
+     # Keep only speaker tags that also occur in the original transcript.
+     speakers = set(speaker_pattern.findall(input_txt))
+     output_txt = speaker_pattern.split(output_txt)
+
+     results = []
+     for idx, txt in enumerate(output_txt):
+         if txt.startswith('说话人'):
+             if txt not in speakers:
+                 continue
+             if idx < len(output_txt) - 1:
+                 if not output_txt[idx + 1].startswith('说话人'):
+                     res = txt + output_txt[idx + 1]
+                 else:
+                     res = txt
+                 results.append(res.strip())
+     return '\n'.join(results)
+
+
+ def filter(pos_tags, neg_tags, input_txt):
+     input_txt_list = preprocess(input_txt)
+     output_txt_list = []
+     for txt in input_txt_list:
+         output_txt = llm(pos_tags, neg_tags, txt)
+         output_txt_list.append(output_txt)
+     output_txt = postprocess(input_txt, output_txt_list)
+     return output_txt
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     args = parser.parse_args()
+
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column(scale=0.3):
+                 with gr.Row():
+                     pos_txt = gr.Textbox(
+                         lines=2,
+                         label='Positive Tags',
+                         elem_id='pos_textbox',
+                         placeholder='Enter positive tags split by semicolon')
+                 with gr.Row():
+                     # Negative tags are collected but hidden and unused.
+                     neg_txt = gr.Textbox(
+                         lines=2,
+                         visible=False,
+                         label='Negative Tags',
+                         elem_id='neg_textbox',
+                         placeholder='Enter negative tags split by semicolon')
+                 with gr.Row():
+                     input_txt = gr.Textbox(
+                         lines=5,
+                         label='Input',
+                         elem_id='input_textbox',
+                         placeholder='Enter text and press submit')
+                 with gr.Row():
+                     submit = gr.Button('Submit')
+                 with gr.Row():
+                     clear = gr.Button('Clear')
+             with gr.Column(scale=0.7):
+                 output_txt = gr.Textbox(
+                     label='Output',
+                     elem_id='output_textbox')
+                 output_txt = output_txt.style(height=690)
+         submit.click(
+             filter,
+             [pos_txt, neg_txt, input_txt],
+             [output_txt])
+         clear.click(
+             lambda: ['', '', ''],
+             None,
+             [pos_txt, neg_txt, input_txt])
+
+     demo.queue(concurrency_count=6)
+     demo.launch()
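
To run the demo, install the openai and gradio packages, set the OPENAI_KEY environment variable, and start app.py with Python; the Gradio UI then filters transcript rounds by the positive tags you provide. The snippet below is a minimal sketch, not part of the commit, that exercises only the preprocess chunking step without calling the OpenAI API. It assumes app.py is importable from the working directory, and the sample transcript and its speaker tags are invented for illustration.

# Hypothetical smoke test for the chunking step in app.py; no API call is made.
from app import preprocess

# Invented transcript in the expected '说话人N hh:mm' format.
sample = (
    '说话人1 00:01 请先做一下自我介绍。'        # "Please introduce yourself."
    '说话人2 00:05 好的，我主要做后端开发，常用 Python。'  # "I mostly do backend work in Python."
    '说话人1 00:20 你为什么想加入我们的团队？'   # "Why do you want to join our team?"
)

chunks = preprocess(sample, max_length=4000, max_convs=4)
for i, chunk in enumerate(chunks):
    print(f'--- chunk {i} ---')
    print(chunk)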