# NOTE: the three lines here ("Spaces:" / "Sleeping" / "Sleeping") were a
# Hugging Face Spaces status banner captured during extraction, not source code.
import os, time, transformers | |
import streamlit as st | |
from model import MRCQuestionAnswering | |
from relevance_ranking import rel_ranking | |
from huggingface_hub import login | |
from infer import * | |
from gg_search import GoogleSearch, getContent | |
ggsearch = GoogleSearch() | |
class Chatbot():
    """Streamlit question-answering chat app.

    Per question the pipeline is: Google search -> (if needed) fetch page
    bodies -> rank passages by relevance -> run the extractive MRC model
    over the top passages and show the extracted answer span.
    """

    def __init__(self):
        # Static page chrome and usage notes.
        st.header('π¦ Question answering')
        st.warning("Warning: the processing may take long cause I have no any GPU now...")
        st.info("This app uses google search engine for each input question...")
        st.info("Type 'clear' to delete chat history...")
        st.info("About me: namnh113")
        # Sidebar API-key field; collected but not used by the QA pipeline itself.
        self.API_KEY = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        # Authenticate to the HF Hub with the Space secret.
        # NOTE(review): raises KeyError when the 'hf_api_key' env var is absent — intentional fail-fast?
        login(token=os.environ['hf_api_key'])
        self.model_checkpoint = 'namnh113/vi-mrc-large'
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model",
            options=[self.model_checkpoint],
            help="List available model to predict"
        )

    def generate_response(self, question):
        """Search the web for `question` and run extractive QA over the results.

        Returns the (whitespace-normalized) answer string, or
        'No answer found !' when the model extracts an empty span.
        Returns None for a blank question (caller only passes non-empty text).
        """
        documents = []  # BUGFIX: ensure bound even if the search below raises
        try:
            links, documents = ggsearch.search(question)
            if not documents:
                # Search gave no snippets: fetch each result page ourselves.
                for url in links:
                    try:
                        docs = getContent(url)
                    except Exception:
                        # Best-effort: a single bad URL no longer aborts the
                        # whole loop (the original wrapped the entire loop).
                        continue
                    # Keep non-trivial pages that are not anti-bot challenge pages.
                    # BUGFIX: original tested the undefined name `doc`, so the
                    # NameError was swallowed and every fetched page was dropped.
                    if len(docs) > 20 and 'The security system for this website has been triggered. Completing the challenge below verifies you are a human and gives you access.' not in docs:
                        documents += [docs]
        except Exception:
            pass  # best-effort: continue with whatever documents we gathered
        passages = rel_ranking(question, documents)
        # get top 40 relevant passages
        passages = '. '.join([p.replace('\n', ', ') for p in passages[:40]])
        QA_input = {
            'question': question,
            'context': passages}
        if len(QA_input['question'].strip()) > 0:
            start = time.time()
            inputs = [tokenize_function(QA_input, tokenizer)]
            inputs_ids = data_collator(inputs, tokenizer)
            outputs = model(**inputs_ids)
            answer = extract_answer(inputs, outputs, tokenizer)[0]
            during = time.time() - start
            print("answer: {}. \nScore start: {}, Score end: {}, Time: {}".format(answer['answer'],
                                                                                  answer['score_start'],
                                                                                  answer['score_end'], during))
            # Collapse runs of whitespace inside the extracted span.
            answer = ' '.join([_.strip() for _ in answer['answer'].split()])
            return answer if answer else 'No answer found !'

    def form_data(self):
        """Render the chat UI, replay history, and answer the current input."""
        try:
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='β ')
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.checkpoint} model")
            # Replay prior turns (Streamlit reruns the script on every event).
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))
            text = st.chat_input(disabled=False)
            if text:
                st.session_state.messages.append(
                    {
                        "role": "user",
                        "content": text
                    }
                )
                with st.chat_message("user"):
                    st.write(text)
                # Magic command: wipe the stored history instead of querying.
                if text.lower() == "clear":
                    del st.session_state.messages
                    return
                result = self.generate_response(text)
                st.session_state.messages.append(
                    {
                        "role": "assistant",
                        "content": result
                    }
                )
                with st.chat_message('assistant'):
                    st.markdown(result)
        except Exception as e:
            # Surface any failure in the UI rather than crashing the app.
            st.error(e, icon="π¨")
# Module entry point — Streamlit re-executes this top to bottom on every
# user interaction. Order matters: Chatbot() draws the page and sets
# model_checkpoint; the tokenizer/model globals below are read by
# generate_response() when form_data() handles the current input.
chatbot = Chatbot()
tokenizer = transformers.AutoTokenizer.from_pretrained(chatbot.model_checkpoint)
model = MRCQuestionAnswering.from_pretrained(chatbot.model_checkpoint)
chatbot.form_data()