Spaces:
Runtime error
Runtime error
from celebbot import CelebBot | |
import streamlit as st | |
import re | |
import spacy | |
import json | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel | |
from utils import * | |
# Streamlit re-executes this script on every interaction, so without caching the
# Flan-T5 weights would be re-read for every message. st.cache_resource keeps the
# loaded model alive across reruns (and sessions). max_entries=1 keeps only the most
# recently requested size, so switching sizes in the sidebar does not accumulate
# several large models in memory at once.
@st.cache_resource(max_entries=1)
def get_seq2seq_model(model_id):
    """Return the seq2seq model for *model_id* (e.g. "google/flan-t5-base"), loaded once and reused."""
    return AutoModelForSeq2SeqLM.from_pretrained(model_id)
# Cached so the encoder is loaded once rather than on every Streamlit rerun.
@st.cache_resource
def get_auto_model(model_id):
    """Return the base transformer model for *model_id*, loaded once and reused across reruns."""
    return AutoModel.from_pretrained(model_id)
# Cached so tokenizer files are loaded once rather than on every Streamlit rerun.
@st.cache_resource
def get_tokenizer(model_id):
    """Return the tokenizer for *model_id*, loaded once and reused across reruns."""
    return AutoTokenizer.from_pretrained(model_id)
def get_celeb_data(fpath):
    """Load and return the parsed JSON document stored at *fpath*.

    The app expects a mapping of celebrity name -> {"gender": ..., "knowledge": ...}.

    The file is decoded as UTF-8 explicitly: without an encoding, ``open`` uses
    the platform locale encoding (e.g. cp1252 on Windows), which mis-decodes or
    rejects non-ASCII names and typographic punctuation.
    """
    with open(fpath, encoding="utf-8") as json_file:
        return json.load(json_file)
# Straight (') or typographic (U+2019) apostrophe. Written as an escape so the
# pattern survives any re-encoding of this source file.
_APOSTROPHE = "['\u2019]"


@st.cache_resource
def _load_spacy_model(model_id):
    """Load the spaCy pipeline *model_id* once instead of on every Streamlit rerun."""
    return spacy.load(model_id)


def _standalone_regex(pattern):
    """Compile *pattern* so it matches only when not adjacent to other word characters.

    Lookarounds are used instead of ``\\b`` because ``\\b`` needs a word character
    on one side: it never matches after a trailing apostrophe ("James' book") or
    period ("Jr. said") that is followed by a space.
    """
    return re.compile(rf"(?<!\w)(?:{pattern})(?!\w)")


def preprocess_text(name, gender, text, model_id):
    """Rewrite third-person *text* about *name* in the first person and split it into sentences.

    Pronouns are rewritten first when *gender* is "M" or "F" (any other value
    skips this step), using the star-imported ``he_regex``/``his_regex``/
    ``she_regex``/``her_regex`` (presumably from ``utils``). Then possessive
    mentions of the full name and of the last name ("Obama's", "James'") become
    "my", and bare mentions become "I".

    Args:
        name: Full name of the celebrity, e.g. "Barack Obama".
        gender: "M" or "F".
        text: Third-person knowledge text about the celebrity.
        model_id: spaCy pipeline name, e.g. "en_core_web_sm".

    Returns:
        ``(spacy_model, texts)``: the loaded spaCy pipeline and the stripped text
        of each sentence it finds in the rewritten text.
    """
    full_name = name.strip()
    name_parts = full_name.split()
    last_name = name_parts[-1] if name_parts else ""

    if gender == "M":
        text = re.sub(he_regex, "I", text)
        text = re.sub(his_regex, "my", text)
    elif gender == "F":
        text = re.sub(she_regex, "I", text)
        text = re.sub(her_regex, "my", text)

    # Possessives go before bare names; otherwise the bare-name pass would turn
    # "Obama's" into "I's". re.escape keeps characters such as "." literal.
    # Empty phrases are skipped: an empty pattern matches at every word boundary
    # and would scatter "I" throughout the text.
    substitutions = []
    for phrase in (full_name, last_name):
        if phrase:
            # Names ending in "s" may take a bare apostrophe ("James'") or "'s".
            suffix = "s?" if phrase.endswith("s") else "s"
            possessive = re.escape(phrase) + _APOSTROPHE + suffix
            substitutions.append((_standalone_regex(possessive), "my"))
    for phrase in (full_name, last_name):
        if phrase:
            substitutions.append((_standalone_regex(re.escape(phrase)), "I"))
    for regex, replacement in substitutions:
        text = regex.sub(replacement, text)

    spacy_model = _load_spacy_model(model_id)
    texts = [sent.text.strip() for sent in spacy_model(text).sents]
    return spacy_model, texts
def _init_session_state():
    """Seed the per-session keys this page reads, leaving existing values untouched."""
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-base",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        # Not read by this page; kept so the session-state schema is unchanged.
        "start_chat": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _render_history():
    """Redraw the stored chat transcript (Streamlit rebuilds the page on every rerun)."""
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def _render_assistant_reply(response, audio_b64):
    """Show *response* in an assistant bubble with a hidden auto-playing WAV clip."""
    import html

    # The reply is model output derived from user input and is rendered with
    # unsafe_allow_html=True, so it must be escaped before entering the HTML.
    safe_response = html.escape(str(response))
    md = (
        f"<p>{safe_response}</p>\n"
        '<audio controls autoplay style="display:none;">\n'
        f'<source src="data:audio/wav;base64,{audio_b64}" type="audio/wav">\n'
        "Your browser does not support the audio element.\n"
        "</audio>\n"
    )
    st.chat_message("assistant").markdown(md, unsafe_allow_html=True)


def main():
    """Render the CelebBot chat page and answer the newest prompt, if any.

    Streamlit re-executes this on every interaction. The bot (and the models it
    needs) is only built when there is a prompt to answer, so reruns without a
    message, such as the initial page load or a sidebar change, stay cheap.
    """
    hide_footer()
    _init_session_state()
    _render_history()

    model_list = ["base", "large", "xl", "xxl"]
    celeb_data = get_celeb_data("data.json")
    # Sidebar controls: persona and Flan-T5 size used for question answering.
    celeb_name = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = celeb_data[celeb_name]["knowledge"]
    model_choice = st.sidebar.selectbox("Choose Your Flan-T5 model", options=model_list)
    st.session_state["QA_model_path"] = f"google/flan-t5-{model_choice}"

    prompt = st.chat_input("Say something")
    if not prompt:
        return

    celeb_bot = CelebBot(
        celeb_name,
        get_tokenizer(st.session_state["QA_model_path"]),
        get_seq2seq_model(st.session_state["QA_model_path"]),
        get_tokenizer(st.session_state["sentTr_model_path"]),
        get_auto_model(st.session_state["sentTr_model_path"]),
        *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm"),
    )
    celeb_bot.text = prompt

    # Display the user message and add it to the chat history.
    st.chat_message("user").markdown(prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})

    response = celeb_bot.question_answer()
    # autoplay=False presumably stops CelebBot from playing the clip itself;
    # playback is left to the <audio autoplay> element rendered below.
    audio_b64 = celeb_bot.text_to_speech(autoplay=False)

    # Display the assistant response, then add it to the chat history.
    _render_assistant_reply(response, audio_b64)
    st.session_state["messages"].append({"role": "assistant", "content": response})
if __name__ == "__main__":
    # Render the app only when this file is executed as a script (for example
    # via `streamlit run`), not when it is imported as a module.
    main()