Spaces:
Runtime error
Runtime error
from celebbot import CelebBot | |
import streamlit as st | |
import time | |
from streamlit_mic_recorder import speech_to_text | |
from utils import * | |
def _init_session_state():
    """Seed ``st.session_state`` with this app's default keys.

    Keys that already exist are left untouched, so values written on earlier
    runs of the same browser session are preserved.
    """
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-xl",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
        "prompt_from_audio": "",
        "prompt_from_text": "",
        "celeb_bot": None,
        "_last_audio_id": 0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _ensure_celeb_bot(celeb_data):
    """Return the session's ``CelebBot``, building it only when needed.

    The bot is (re)built when none exists yet, or when the selected celebrity
    or either model path differs from the ones it was built with. Otherwise
    the instance stored in ``st.session_state["celeb_bot"]`` is reused, so the
    tokenizers/models and the preprocessed knowledge are not reloaded on every
    Streamlit rerun.

    Args:
        celeb_data: mapping from celebrity name to a record with at least
            ``"gender"`` and ``"knowledge"`` entries.

    Returns:
        The ``CelebBot`` stored in ``st.session_state["celeb_bot"]``.

    Raises:
        KeyError: if the selected celebrity's record lacks a required entry.
    """
    celeb_name = st.session_state["celeb_name"]
    qa_path = st.session_state["QA_model_path"]
    sent_path = st.session_state["sentTr_model_path"]
    bot_key = (celeb_name, qa_path, sent_path)
    if (
        st.session_state["celeb_bot"] is not None
        and st.session_state.get("_celeb_bot_key") == bot_key
    ):
        return st.session_state["celeb_bot"]

    # Look up the record first so a malformed entry fails before any model loads.
    celeb = celeb_data[celeb_name]
    celeb_gender = celeb["gender"]
    knowledge = celeb["knowledge"]
    st.session_state["celeb_bot"] = CelebBot(
        celeb_name,
        get_tokenizer(qa_path),
        # Flan-T5 paths use the seq2seq loader; any other path uses the causal loader.
        get_seq2seq_model(qa_path) if "flan-t5" in qa_path else get_causal_model(qa_path),
        get_tokenizer(sent_path),
        get_auto_model(sent_path),
        *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm"),
    )
    # Remember what the bot was built from so later reruns can reuse it.
    st.session_state["_celeb_bot_key"] = bot_key
    return st.session_state["celeb_bot"]


def _assistant_html(response, audio_b64):
    """Return HTML for an assistant turn: the reply text plus a WAV player.

    ``response`` is generated from the user's prompt and is rendered with
    ``unsafe_allow_html=True``, so it is HTML-escaped before being embedded.
    ``audio_b64`` is placed in a base64 ``data:`` URI for the ``<audio>`` tag.
    """
    import html

    return (
        f"<p>{html.escape(str(response))}</p>\n"
        '<audio controls controlsList="autoplay nodownload">\n'
        f'<source src="data:audio/wav;base64,{audio_b64}" type="audio/wav">\n'
        "Your browser does not support the audio element.\n"
        "</audio>"
    )


def _respond(bot, prompt, dialogue_container):
    """Show *prompt* as a user turn, ask *bot* for a reply, and show it with audio.

    Both turns are appended to ``st.session_state["messages"]`` so they are
    replayed on later reruns (as plain text, without the audio player).
    """
    bot.text = prompt
    with dialogue_container:
        st.chat_message("user").markdown(prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})

    response = bot.question_answer()
    # autoplay=False: playback is left to the HTML <audio> player built below.
    audio_b64 = bot.text_to_speech(autoplay=False)
    with dialogue_container:
        st.chat_message("assistant").markdown(
            _assistant_html(response, audio_b64),
            unsafe_allow_html=True,
        )
    st.session_state["messages"].append({"role": "assistant", "content": response})


def main():
    """Render the CelebChat page for one Streamlit script run.

    Builds the sidebar (celebrity/model pickers, speech-to-text recorder and a
    text box), replays the stored chat history, and, if a new prompt arrived
    by voice or text, answers it with the selected celebrity's bot.
    """
    hide_footer()
    model_list = ["flan-t5-xl"]
    celeb_data = get_celeb_data("data.json")

    st.sidebar.header("CelebChat")
    with st.sidebar.expander("About the app"):
        st.markdown("This app is a demo of celebrity chatting!")

    _init_session_state()

    def text_submit():
        # on_change callback for the text box: stash the submitted text for
        # this run and clear the input widget.
        st.session_state["prompt_from_text"] = st.session_state.widget
        st.session_state.widget = ""

    st.session_state["celeb_name"] = st.sidebar.selectbox(
        "Choose a celebrity", options=list(celeb_data.keys())
    )
    model_id = st.sidebar.selectbox("Choose Your Flan-T5 model", options=model_list)
    st.session_state["QA_model_path"] = (
        f"google/{model_id}" if "flan-t5" in model_id else model_id
    )

    bot = _ensure_celeb_bot(celeb_data)

    dialogue_container = st.container()
    with dialogue_container:
        for message in st.session_state["messages"]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    with st.sidebar:
        st.session_state["prompt_from_audio"] = speech_to_text(
            start_prompt="Start Recording",
            stop_prompt="Stop Recording",
            language="en",
            use_container_width=True,
            just_once=True,
            key="STT",
        )
        st.text_input("Or write something", key="widget", on_change=text_submit)

    # Prefer a fresh transcription, otherwise the submitted text. Using ``or``
    # means an empty/None transcription no longer hides a typed prompt, and
    # ``prompt`` is always bound.
    prompt = st.session_state["prompt_from_audio"] or st.session_state["prompt_from_text"]
    if prompt and prompt.strip():
        _respond(bot, prompt, dialogue_container)

    # Consume both prompt sources so the same prompt is not answered twice.
    st.session_state["prompt_from_audio"] = ""
    st.session_state["prompt_from_text"] = ""


if __name__ == "__main__":
    main()