import re
import json
import argparse
from functools import partial

import openai
import gradio as gr


class GPT4News():

    def __init__(self, prompt_formats):
        # Index the prompt templates by their 'name' field.
        self.name2prompt = {x['name']: x for x in prompt_formats}

    def preprocess(self, function_name, input_txt):
        """Split a speaker-tagged transcript into chunks that fit the prompt limits."""
        max_length = self.name2prompt[function_name]['split_length']
        max_convs = self.name2prompt[function_name]['split_round']
        # Normalize Chinese speaker tags ("说话人1 00:00") to "Speaker 1 00:00".
        input_txt = re.sub(r'(说话人)(\d+ \d\d:\d\d)', r'Speaker \2', input_txt)
        speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
        input_txt = speaker_pattern.split(input_txt)
        input_txt = [x.strip().replace('\n', ' ') for x in input_txt]

        # Pair each speaker tag with the utterance that follows it, splitting
        # utterances longer than max_length into multiple pieces.
        conversations = []
        for idx, txt in enumerate(input_txt):
            if not speaker_pattern.match(txt):
                continue
            if idx + 1 < len(input_txt) and not speaker_pattern.match(input_txt[idx + 1]):
                conv = [txt, input_txt[idx + 1]]
            else:
                conv = [txt, '']
            while len(''.join(conv)) > max_length:
                pruned_len = max_length - len(conv[0])
                conversations.append([txt, conv[1][:pruned_len]])
                conv = [txt, conv[1][pruned_len:]]
            conversations.append(conv)

        # Pack consecutive turns into chunks bounded by max_length characters
        # and max_convs speaker turns.
        input_txt_list = ['']
        for conv in conversations:
            conv_length = len(''.join(conv))
            if len(input_txt_list[-1]) + conv_length >= max_length:
                input_txt_list.append('')
            elif len(speaker_pattern.findall(input_txt_list[-1])) >= max_convs:
                input_txt_list.append('')
            input_txt_list[-1] += ''.join(conv)

        # Re-insert line breaks so each chunk reads "Speaker N MM:SS: text".
        processed_txt_list = []
        for input_txt in input_txt_list:
            input_txt = speaker_pattern.sub(r'\n\1: ', input_txt)
            processed_txt_list.append(input_txt.strip())
        return processed_txt_list
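
    # Illustrative input for preprocess(). The transcript content below is
    # hypothetical; only the "说话人N MM:SS" / "Speaker N MM:SS" tag format is
    # assumed by the regexes above:
    #
    #   说话人1 00:00
    #   Hello and welcome to the show.
    #   说话人2 00:15
    #   Thanks for having me.
    #
    # preprocess() would normalize the tags to "Speaker 1 00:00" etc. and return
    # chunks such as:
    #   "Speaker 1 00:00: Hello and welcome to the show.\nSpeaker 2 00:15: Thanks for having me."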
    def chatgpt(self, messages, temperature=0.0):
        """Call the OpenAI chat API, retrying on errors."""
        try:
            completion = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=temperature
            )
            return completion.choices[0].message.content
        except Exception as err:
            print(err)
            return self.chatgpt(messages, temperature)

    def llm(self, function_name, temperature, **kwargs):
        """Fill the selected prompt template and query the model."""
        prompt = self.name2prompt[function_name]
        user_kwargs = {key: kwargs[key] for key in prompt['user_keys']}
        user = prompt['user'].format(**user_kwargs)
        system_kwargs = {key: kwargs[key] for key in prompt['system_keys']}
        system = prompt['system'].format(**system_kwargs)
        messages = [
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': user}]
        response = self.chatgpt(messages, temperature=temperature)
        print(f'SYSTEM:\n\n{system}')
        print(f'USER:\n\n{user}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def translate(self, txt, output_lang):
        """Translate text into the requested language (no-op for English)."""
        if output_lang == 'English':
            return txt
        system = 'You are a translator.'
        user = 'Translate the following text to {}:\n\n{}'.format(
            output_lang, txt)
        messages = [
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': user}]
        response = self.chatgpt(messages)
        print(f'SYSTEM:\n\n{system}')
        print(f'USER:\n\n{user}')
        print(f'RESPONSE:\n\n{response}')
        return response

    def postprocess(self, function_name, input_txt, output_txt_list, output_lang):
        """Merge per-chunk outputs; optionally keep only lines from known speakers."""
        if not self.name2prompt[function_name]['post_filter']:
            output_txt = '\n\n'.join(output_txt_list)
            output_txt = self.translate(output_txt, output_lang)
            return output_txt
        speaker_pattern = re.compile(r'(Speaker \d+ \d\d:\d\d)')
        # Drop chunks that contain no speaker tags at all.
        output_txt = []
        for txt in output_txt_list:
            if len(speaker_pattern.findall(txt)) > 0:
                output_txt.append(txt)
        output_txt = ''.join(output_txt)
        # Keep only speakers that actually appear in the original input.
        speakers = set(speaker_pattern.findall(input_txt))
        output_txt = speaker_pattern.split(output_txt)
        results = []
        for idx, txt in enumerate(output_txt):
            if not speaker_pattern.match(txt):
                continue
            if txt not in speakers:
                continue
            if idx + 1 < len(output_txt) and not speaker_pattern.match(output_txt[idx + 1]):
                res = txt + output_txt[idx + 1]
            else:
                res = txt
            res = self.translate(res, output_lang)
            results.append(res.strip())
        return '\n\n'.join(results)

    def __call__(self, api_key, function_name, temperature, output_lang,
                 input_txt, tags):
        if api_key is None or api_key == '':
            return 'OPENAI API Key is not set.'
        if function_name is None or function_name == '':
            return 'Function is not selected.'
        openai.api_key = api_key
        input_txt_list = self.preprocess(function_name, input_txt)
        input_txt = '\n'.join(input_txt_list)
        output_txt_list = []
        for txt in input_txt_list:
            llm_kwargs = dict(input_txt=txt, tags=tags)
            output_txt = self.llm(function_name, temperature, **llm_kwargs)
            output_txt_list.append(output_txt)
        output_txt = self.postprocess(
            function_name, input_txt, output_txt_list, output_lang)
        return output_txt

    @property
    def function_names(self):
        return self.name2prompt.keys()


def function_name_select_callback(components, name2prompt, function_name):
    """Show or hide input widgets depending on which keys the selected prompt uses."""
    prompt = name2prompt[function_name]
    user_keys = prompt['user_keys']
    result = []
    for comp in components:
        result.append(gr.update(visible=comp in user_keys))
    return result
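
# The code only assumes that each entry in the prompt file carries the keys it
# reads: 'name', 'split_length', 'split_round', 'post_filter', 'system',
# 'system_keys', 'user' and 'user_keys'. A hypothetical entry in
# prompts/interview.json could therefore look like this (the concrete values
# are illustrative, not taken from the repository):
#
# [
#   {
#     "name": "summarize",
#     "split_length": 2000,
#     "split_round": 20,
#     "post_filter": false,
#     "system": "You are a journalist summarizing an interview.",
#     "system_keys": [],
#     "user": "Summarize the following transcript:\n\n{input_txt}",
#     "user_keys": ["input_txt"]
#   }
# ]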
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt', type=str, default='prompts/interview.json',
                        help='path to the prompt file')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='temperature for the llm model')
    args = parser.parse_args()

    prompt_formats = json.load(open(args.prompt, 'r'))
    gpt4news = GPT4News(prompt_formats)

    languages = ['Arabic', 'Bengali', 'Chinese (Simplified)',
                 'Chinese (Traditional)', 'Dutch', 'English', 'French',
                 'German', 'Hindi', 'Italian', 'Japanese', 'Korean',
                 'Portuguese', 'Punjabi', 'Russian', 'Spanish', 'Turkish',
                 'Urdu']
    default_func = sorted(gpt4news.function_names)[0]
    default_user_keys = gpt4news.name2prompt[default_func]['user_keys']

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0.3):
                with gr.Row():
                    api_key = gr.Textbox(
                        lines=1,
                        label='OPENAI API Key',
                        elem_id='api_key_textbox',
                        placeholder='Enter your OPENAI API Key')
                with gr.Row():
                    function_name = gr.Dropdown(
                        sorted(gpt4news.function_names),
                        value=default_func,
                        elem_id='function_dropdown',
                        label='Function',
                        info='choose a function to run')
                with gr.Row():
                    output_lang = gr.Dropdown(
                        languages,
                        value='English',
                        elem_id='output_lang_dropdown',
                        label='Output Language',
                        info='choose a language to output')
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=args.temperature,
                        step=0.1,
                        interactive=True,
                        label='Temperature',
                        info='higher temperature means more creative')
                with gr.Row():
                    tags = gr.Textbox(
                        lines=1,
                        visible='tags' in default_user_keys,
                        label='Tags',
                        elem_id='tags_textbox',
                        placeholder='Enter tags split by semicolon')
                with gr.Row():
                    input_txt = gr.Textbox(
                        lines=4,
                        visible='input_txt' in default_user_keys,
                        label='Input',
                        elem_id='input_textbox',
                        placeholder='Enter text and press submit')
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
            with gr.Column(scale=0.7):
                output_txt = gr.Textbox(
                    lines=8,
                    label='Output',
                    elem_id='output_textbox')

        function_name.select(
            partial(function_name_select_callback,
                    ['input_txt', 'tags'], gpt4news.name2prompt),
            [function_name],
            [input_txt, tags])
        submit.click(
            gpt4news,
            [api_key, function_name, temperature, output_lang, input_txt, tags],
            [output_txt])
        clear.click(
            lambda: ['', '', ''],
            None,
            [tags, input_txt, output_txt])

    demo.queue(concurrency_count=6)
    demo.launch()
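
# Example launch command (the script filename is hypothetical):
#   python gpt4news_app.py --prompt prompts/interview.json --temperature 0.7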