Spaces:
Runtime error
Runtime error
File size: 4,502 Bytes
6bc94ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
from celebbot import CelebBot
import streamlit as st
import re
import spacy
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from utils import *
@st.cache_resource
def get_seq2seq_model(model_id):
    """Return the pretrained sequence-to-sequence model named ``model_id``.

    ``st.cache_resource`` memoises the loaded model per ``model_id``, so
    Streamlit reruns reuse it instead of loading the checkpoint again.
    """
    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    return seq2seq_model
@st.cache_resource
def get_auto_model(model_id):
    """Return the pretrained base (``AutoModel``) model named ``model_id``.

    Cached with ``st.cache_resource`` so each checkpoint is loaded only once
    and then reused across Streamlit reruns.
    """
    base_model = AutoModel.from_pretrained(model_id)
    return base_model
@st.cache_resource
def get_tokenizer(model_id):
    """Return the tokenizer that belongs to the checkpoint ``model_id``.

    Wrapped in ``st.cache_resource`` so repeated reruns share one tokenizer
    object per ``model_id``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return tokenizer
@st.cache_data
def get_celeb_data(fpath):
    """Read and parse the celebrity JSON file at ``fpath``.

    In this app callers index the result by celebrity name and read its
    ``"gender"`` and ``"knowledge"`` entries. ``st.cache_data`` caches the
    parsed result so the file is not re-read on every Streamlit rerun.

    Args:
        fpath: Path to the JSON file.

    Returns:
        The decoded JSON document.
    """
    # JSON text is UTF-8; relying on the platform default codec (e.g. cp1252
    # on Windows) would mis-decode or reject non-ASCII names and biographies.
    with open(fpath, encoding="utf-8") as json_file:
        return json.load(json_file)
def _whole_word_regex(word):
    r"""Compile a regex matching ``word`` literally, not as part of a longer word.

    ``re.escape`` keeps characters such as ``.``, ``(`` or ``+`` in a name from
    being interpreted as regex syntax. ``(?<!\w)`` / ``(?!\w)`` are used
    instead of ``\b`` because ``\b`` only marks a word edge next to a word
    character.
    """
    return re.compile(rf"(?<!\w){re.escape(word)}(?!\w)")


def _possessive_regex(word):
    r"""Compile a regex matching the possessive form of ``word``.

    Both the ASCII apostrophe and the typographic one (U+2019) are accepted.
    Words ending in "s" match both "Jones's" and "Jones'"; other words
    require "'s".
    """
    suffix = "['\u2019]s?" if word.endswith("s") else "['\u2019]s"
    return re.compile(rf"(?<!\w){re.escape(word)}{suffix}(?!\w)")


@st.cache_resource
def preprocess_text(name, gender, text, model_id):
    """Rewrite third-person text about ``name`` into first person and split it.

    Args:
        name: Full name. The last whitespace-separated token is treated as
            the surname.
        gender: ``"M"`` rewrites the ``he``/``his`` patterns and ``"F"`` the
            ``she``/``her`` patterns (regexes defined in ``utils``). Any other
            value skips pronoun rewriting.
        text: Knowledge text to rewrite.
        model_id: spaCy pipeline name passed to ``spacy.load``.

    Returns:
        ``(spacy_model, texts)``, where ``texts`` holds the stripped sentence
        strings produced by spaCy's sentence segmentation of the rewritten
        text.
    """
    if gender == "M":
        text = re.sub(he_regex, "I", text)
        text = re.sub(his_regex, "my", text)
    elif gender == "F":
        text = re.sub(she_regex, "I", text)
        text = re.sub(her_regex, "my", text)

    full_name = name.strip()
    if full_name:
        # split() rather than split(" "): stray or double spaces must not
        # yield an empty surname, whose pattern would match at every gap.
        lname = full_name.split()[-1]
        # Order matters: possessives before bare names (else "Obama's" would
        # become "I's"), and full name before surname (else "Barack Obama"
        # would become "Barack I").
        for pattern, replacement in (
            (_possessive_regex(full_name), "my"),
            (_possessive_regex(lname), "my"),
            (_whole_word_regex(full_name), "I"),
            (_whole_word_regex(lname), "I"),
        ):
            text = pattern.sub(replacement, text)

    # NOTE(review): the cache key includes name/text, so a separate spaCy
    # pipeline is loaded and retained for every distinct celebrity.
    spacy_model = spacy.load(model_id)
    texts = [sent.text.strip() for sent in spacy_model(text).sents]
    return spacy_model, texts
def _init_session_state():
    """Seed per-session keys in ``st.session_state`` on the first run only."""
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-base",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        # NOTE(review): not read anywhere in this file; kept in case other
        # code relies on it.
        "start_chat": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def main():
    """Render one run of the celebrity chat page.

    Replays the stored conversation, lets the user pick a celebrity and a
    Flan-T5 size in the sidebar, and builds a ``CelebBot`` from the cached
    models. When a prompt is submitted, it answers with text plus a hidden,
    autoplaying audio clip.
    """
    import html  # used to escape model output before embedding it in raw HTML

    hide_footer()
    _init_session_state()
    model_list = ["base", "large", "xl", "xxl"]

    # Streamlit re-executes the whole script on each interaction, so the
    # stored conversation is redrawn before anything new is added.
    for message in st.session_state["messages"]:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    celeb_data = get_celeb_data("data.json")

    # Sidebar controls: which celebrity to impersonate and which model size.
    celeb_name = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = celeb_data[celeb_name]["knowledge"]
    model_choice = st.sidebar.selectbox("Choose Your Flan-T5 model", options=model_list)
    st.session_state["QA_model_path"] = f"google/flan-t5-{model_choice}"

    celeb_bot = CelebBot(
        celeb_name,
        get_tokenizer(st.session_state["QA_model_path"]),
        get_seq2seq_model(st.session_state["QA_model_path"]),
        get_tokenizer(st.session_state["sentTr_model_path"]),
        get_auto_model(st.session_state["sentTr_model_path"]),
        *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm"),
    )

    prompt = st.chat_input("Say something")
    if prompt:
        celeb_bot.text = prompt
        # Display user message in chat message container
        st.chat_message("user").markdown(prompt)
        # Add user message to chat history
        st.session_state["messages"].append({"role": "user", "content": prompt})

        response = celeb_bot.question_answer()
        # autoplay=False here; playback is handled by the HTML <audio autoplay>
        # element below.
        b64 = celeb_bot.text_to_speech(autoplay=False)
        # The response is model output derived from user input and is rendered
        # with unsafe_allow_html, so it must be HTML-escaped. Lines are kept
        # unindented: Markdown treats 4+ leading spaces as a code block.
        md = (
            f"<p>{html.escape(str(response))}</p>\n"
            '<audio controls autoplay style="display:none;">\n'
            f'<source src="data:audio/wav;base64,{b64}" type="audio/wav">\n'
            "Your browser does not support the audio element.\n"
            "</audio>\n"
        )
        # Display assistant response in chat message container
        st.chat_message("assistant").markdown(
            md,
            unsafe_allow_html=True,
        )
        # Add assistant response to chat history
        st.session_state["messages"].append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()
|