import transformers
import gradio as gr
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

# Add padding, begin-of-string, and end-of-string tokens.
# NOTE: the angle-bracket token strings were stripped from the original
# source (likely by HTML processing); "<pad>", "<bos>", and "<eos>" are
# assumed reconstructions consistent with the comments and the usage below.
tokenizer.add_special_tokens({
    "pad_token": "<pad>",
    "bos_token": "<bos>",
    "eos_token": "<eos>",
})
# Add the bot token, since it is not a special token. "<bot>:" is likewise
# an assumed reconstruction; only the trailing colon survived.
tokenizer.add_tokens(["<bot>:"])
print("=====Done 1")

model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))
model.load_state_dict(torch.load('./gpt2talk.pt', map_location=torch.device('cpu')))
model.to(device)
print("=====Done 2")

model.eval()


def inference(quiz):
    quiz1 = quiz
    # Wrap the question in the same prompt format used during fine-tuning.
    quiz = "<bos>" + quiz + " <bot>:"
    quiztoken = tokenizer(quiz, return_tensors='pt').to(device)
    # do_sample=True is required for top_k/top_p to take effect. top_k must be
    # an integer; the original passed the float 0.7, which is invalid, so the
    # transformers default of 50 is assumed here.
    answer = model.generate(
        **quiztoken,
        max_length=200,
        do_sample=True,
        top_k=50,
        top_p=0.1,
        pad_token_id=tokenizer.pad_token_id,
    )[0]
    answer = tokenizer.decode(answer, skip_special_tokens=True)
    # Strip the prompt scaffolding so only the model's reply remains.
    answer = answer.replace(" <bot>:", "").replace(quiz1, "") + '.'
    return answer


def chatbot(input_text):
    return inference(input_text)


# Create and launch the Gradio interface.
print("=====Done 3")
gr.Interface(
    fn=chatbot,
    inputs='text',
    outputs='text',
    live=False,  # run only on submit rather than on every keystroke
    title="ChatFinance",
    description="Ask the bot a question and see its response!",
).launch()
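# Usage note (a sketch, assuming the fine-tuned checkpoint exists at
# ./gpt2talk.pt and this file is saved as app.py, a filename assumed here):
# running `python app.py` loads the weights and serves the interface at
# Gradio's default local address, http://127.0.0.1:7860.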