"""Streamlit front end for CelebChat: chat with a celebrity persona bot.

The sidebar lets the user pick a celebrity and a question-answering model,
then send prompts either by voice (``speech_to_text``) or by typing.
"""
from celebbot import CelebBot
import streamlit as st
from streamlit_mic_recorder import speech_to_text
# NOTE(review): star import; presumably supplies hide_footer, get_celeb_data,
# get_tokenizer, get_seq2seq_model, get_causal_model, get_auto_model and
# preprocess_text used below — confirm and make explicit.
from utils import *


def _build_celeb_bot(celeb_name, celeb_info, qa_model_path, sent_model_path):
    """Construct a CelebBot for ``celeb_name`` from the given model paths.

    ``celeb_info`` is this celebrity's entry from ``data.json``; its
    ``"gender"`` and ``"knowledge"`` fields are passed to ``preprocess_text``.
    """
    qa_tokenizer = get_tokenizer(qa_model_path)
    # Flan-T5 checkpoints go through the seq2seq loader; every other model id
    # goes through the causal-LM loader.
    if "flan-t5" in qa_model_path:
        qa_model = get_seq2seq_model(qa_model_path)
    else:
        qa_model = get_causal_model(qa_model_path)
    return CelebBot(
        celeb_name,
        qa_tokenizer,
        qa_model,
        get_tokenizer(sent_model_path),
        get_auto_model(sent_model_path),
        *preprocess_text(
            celeb_name,
            celeb_info["gender"],
            celeb_info["knowledge"],
            "en_core_web_sm",
        ),
    )


def _queue_text_prompt():
    """Move the typed prompt into a one-shot slot and clear the text box.

    Streamlit keeps a text_input's value across reruns, so reading the widget
    directly would resubmit the same prompt on every rerun. Moving the value
    out in the on_change callback makes each typed prompt fire exactly once.
    """
    st.session_state["_pending_text_prompt"] = st.session_state["_text_prompt"]
    st.session_state["_text_prompt"] = ""


def main():
    """Render the CelebChat page and handle at most one user prompt per rerun."""
    hide_footer()
    # NOTE(review): non-Flan ids are passed through unchanged as the model
    # path; "Falcon-7b-instruct" has no org prefix — confirm get_causal_model
    # resolves it.
    model_list = ["flan-t5-large", "flan-t5-xl", "Falcon-7b-instruct"]
    celeb_data = get_celeb_data('data.json')

    st.sidebar.header("CelebChat")
    expander = st.sidebar.expander('About the app')
    with expander:
        st.markdown("This app is a demo of celebrity chatting!")

    # Seed session state on the first run only.
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-base",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
        "prompt": None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    def start_chat(name, model_id):
        # Enable chatting only when both a celebrity and a model were chosen
        # (a selectbox over an empty option list yields None).
        st.session_state["start_chat"] = bool(name) and bool(model_id)

    with st.sidebar.form("my_form"):
        st.session_state["celeb_name"] = st.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
        model_id = st.selectbox("Choose Your Flan-T5 model", options=model_list)
        st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
        st.form_submit_button(
            label="Start Chatting",
            on_click=start_chat,
            args=(st.session_state["celeb_name"], st.session_state["QA_model_path"]),
        )

    if st.session_state["start_chat"]:
        bot_key = (
            st.session_state["celeb_name"],
            st.session_state["QA_model_path"],
            st.session_state["sentTr_model_path"],
        )
        # Building the bot loads models and runs spaCy preprocessing, so only
        # rebuild when the selection actually changed, not on every rerun.
        if "celeb_bot" not in st.session_state or st.session_state.get("_celeb_bot_key") != bot_key:
            st.session_state["celeb_bot"] = _build_celeb_bot(
                st.session_state["celeb_name"],
                celeb_data[st.session_state["celeb_name"]],
                st.session_state["QA_model_path"],
                st.session_state["sentTr_model_path"],
            )
            st.session_state["_celeb_bot_key"] = bot_key

    # Replay the chat history accumulated in earlier reruns.
    dialogue_container = st.container()
    with dialogue_container:
        for message in st.session_state["messages"]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    with st.sidebar:
        prompt_from_audio = speech_to_text(
            start_prompt="Start Recording",
            stop_prompt="Stop Recording",
            language='en',
            use_container_width=True,
            just_once=True,
            key='STT',
        )
        st.text_input('Or write something', key="_text_prompt", on_change=_queue_text_prompt)

    # Always consume the queued text so it can never be replayed later.
    pending_text = st.session_state.pop("_pending_text_prompt", None)
    prompt = prompt_from_audio if prompt_from_audio is not None else pending_text
    st.session_state["prompt"] = prompt
    if not prompt:
        return

    bot = st.session_state.get("celeb_bot")
    if not st.session_state["start_chat"] or bot is None:
        st.sidebar.warning("Choose a celebrity and click 'Start Chatting' before sending a message.")
        return

    bot.text = prompt
    # Display the user message and add it to the chat history.
    with dialogue_container:
        st.chat_message("user").markdown(prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})

    response = bot.question_answer()
    # NOTE(review): the return value (originally bound to `b64`) is not used
    # anywhere — the message template no longer embeds it. The call is kept in
    # case text_to_speech has side effects; confirm, then remove it or re-wire
    # an audio player (escaping `response` if HTML is re-enabled).
    bot.text_to_speech(autoplay=False)

    # Render model output as plain markdown (no unsafe HTML), matching how the
    # history replay above renders it.
    with dialogue_container:
        st.chat_message("assistant").markdown(response)
    st.session_state["messages"].append({"role": "assistant", "content": response})


if __name__ == "__main__":
    main()