import os
import time

import transformers
import streamlit as st

from model import MRCQuestionAnswering
from infer import *
from gg_search import GoogleSearch, getContent
from relevance_ranking import rel_ranking

ggsearch = GoogleSearch()


class Chatbot:
    """Streamlit chat app: extractive QA over Google-search results.

    For each user question it fetches candidate pages via Google search,
    ranks passages by relevance, and extracts an answer span with a
    Vietnamese MRC model (namnh113/vi-mrc-large).
    """

    def __init__(self):
        # Static page chrome / usage notes.
        st.header('🦜 Question answering')
        st.warning("Warning: the processing may take long cause I have no any GPU now...")
        st.info("This app uses google search engine for each input question...")
        st.info("Type 'clear' to delete chat history...")
        st.info("About me: namnh113")

        model_checkpoint = 'namnh113/vi-mrc-large'
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint)
        self.model = MRCQuestionAnswering.from_pretrained(model_checkpoint)

        # Sidebar controls: API key gate and (single-option) model selector.
        self.API_KEY = st.sidebar.text_input(
            'API key',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model",
            options=[model_checkpoint],
            help="List available model to predict"
        )

    def generate_response(self, input_text):
        """Answer *input_text* by searching the web and running the MRC model.

        Returns the extracted answer string, or 'No answer found !' when the
        question is blank or the model yields an empty span.
        """
        documents = []
        try:
            # BUGFIX: original referenced the undefined name `question` here.
            links, documents = ggsearch.search(input_text)
            if not documents:
                # Search returned only links: scrape each page ourselves.
                # One failing URL must not abort the whole loop, so the
                # try/except sits inside the loop (original had it outside).
                for url in links:
                    try:
                        docs = getContent(url)
                        # Skip tiny pages and anti-bot challenge pages.
                        # BUGFIX: original tested `not in doc` (undefined name).
                        if len(docs) > 20 and 'The security system for this website has been triggered. Completing the challenge below verifies you are a human and gives you access.' not in docs:
                            documents.append(docs)
                    except Exception:
                        continue
        except Exception:
            # Best-effort retrieval: fall through with whatever we collected.
            pass

        # get top 40 relevant passages, flattened into one context string
        passages = rel_ranking(input_text, documents)
        passages = '. '.join([p.replace('\n', ', ') for p in passages[:40]])

        QA_input = {
            'question': input_text,
            'context': passages
        }
        # BUGFIX: original fell through to `return answer` with `answer`
        # unbound when the question was blank (UnboundLocalError).
        if not QA_input['question'].strip():
            return 'No answer found !'

        start = time.time()
        # BUGFIX: original used bare `tokenizer` / `model` — these are
        # instance attributes, not module globals.
        inputs = [tokenize_function(QA_input, self.tokenizer)]
        inputs_ids = data_collator(inputs, self.tokenizer)
        outputs = self.model(**inputs_ids)
        answer = extract_answer(inputs, outputs, self.tokenizer)[0]
        during = time.time() - start

        # BUGFIX: original string literal was split by a raw newline
        # (syntax error); use an escaped \n instead.
        print("answer: {}. \nScore start: {}, Score end: {}, Time: {}".format(
            answer['answer'], answer['score_start'], answer['score_end'], during))

        # Collapse all whitespace runs in the answer to single spaces.
        answer = ' '.join([token.strip() for token in answer['answer'].split()])
        return answer if answer else 'No answer found !'

    def form_data(self):
        """Render the chat UI: history, input box, and assistant replies."""
        try:
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='⚠')
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.checkpoint} model")

            # Replay prior turns so the conversation persists across reruns.
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))

            text = st.chat_input(disabled=False)
            if text:
                st.session_state.messages.append(
                    {
                        "role": "user",
                        "content": text
                    }
                )
                with st.chat_message("user"):
                    st.write(text)

                # 'clear' wipes the history instead of being answered.
                if text.lower() == "clear":
                    del st.session_state.messages
                    return

                result = self.generate_response(text)
                st.session_state.messages.append(
                    {
                        "role": "assistant",
                        "content": result
                    }
                )
                with st.chat_message('assistant'):
                    st.markdown(result)
        except Exception as e:
            st.error(e, icon="🚨")


# BUGFIX: original instantiated `LLM_Langchain()`, a class that does not
# exist anywhere in this file — the defined class is `Chatbot`.
model = Chatbot()
model.form_data()