"""Streamlit front end for CelebChat: chat with a celebrity persona by text or voice."""

import streamlit as st
from streamlit_mic_recorder import speech_to_text

from celebbot import CelebBot
from utils import *

# Celebrities whose Britannica biography slug is not simply the display name
# with spaces replaced by hyphens.
_BRITANNICA_SLUG_OVERRIDES = {
    "Madonna": "Madonna-American-singer-and-actress",
    "Anne Hathaway": "Anne-Hathaway-American-actress",
}

# Canned questions offered as one-click buttons in the sidebar.
_EXAMPLE_QUESTIONS = (
    "Hello! Did you win an Oscar?",
    "Hi! What is your profession?",
    "Can you tell me about your family background?",
)


def _init_session_state() -> None:
    """Seed every session-state key the app reads, leaving existing values alone."""
    # Built on each call so that a new session gets its own "messages" list.
    defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-xl",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
        "prompt_from_audio": "",
        "prompt_from_text": "",
        "celeb_bot": None,
        "_last_audio_id": 0,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _britannica_slug(celeb_name: str) -> str:
    """Return the last path segment of the Britannica biography URL for *celeb_name*."""
    override = _BRITANNICA_SLUG_OVERRIDES.get(celeb_name)
    if override is not None:
        return override
    return "-".join(celeb_name.split(" "))


def main() -> None:
    """Render the CelebChat page and answer at most one pending prompt.

    Streamlit re-executes this function on every interaction. Each run draws
    the sidebar, rebuilds the ``CelebBot`` for the selected celebrity, replays
    the stored conversation, and, when a spoken, typed, or example prompt is
    pending, asks the bot for an answer and appends both turns to the history.
    """
    st.set_page_config(initial_sidebar_state="expanded")
    hide_footer()

    model_list = ["flan-t5-xl"]
    celeb_data = get_celeb_data("data.json")

    st.sidebar.header("CelebChat")
    with st.sidebar.expander("About the app"):
        st.markdown("Experience the ultimate celebrity chats with this app!")
    with st.sidebar.expander("Disclaimer"):
        st.markdown(
            "\n"
            "* CelebChat may produce inaccurate information about people, places, or facts.\n"
            "* If you have any concerns about your privacy or believe that the app infringes"
            " on your rights, please contact me at liuhaozhe2000@gmail.com. I am committed to"
            " addressing your concerns and taking any necessary corrective actions.\n"
        )

    _init_session_state()

    def text_submit():
        # Move the typed text into the pending-prompt slot and empty the box.
        st.session_state["prompt_from_text"] = st.session_state.text_input
        st.session_state.text_input = ""

    def example_submit(text):
        # An example button behaves like typing that question.
        st.session_state["prompt_from_text"] = text

    def clear_chat_his():
        # Switching celebrity starts a fresh conversation.
        st.session_state["messages"] = []

    st.sidebar.selectbox(
        "Choose your celebrity crush",
        key="celeb_name",
        options=sorted(celeb_data),
        on_change=clear_chat_his,
    )
    model_id = st.sidebar.selectbox("Choose Your model", options=model_list)
    st.session_state["QA_model_path"] = (
        f"google/{model_id}" if "flan-t5" in model_id else model_id
    )

    celeb_name = st.session_state["celeb_name"]
    celeb_gender = celeb_data[celeb_name]["gender"]
    knowledge = get_article(
        f"https://www.britannica.com/biography/{_britannica_slug(celeb_name)}"
    )

    qa_model_path = st.session_state["QA_model_path"]
    sent_tr_model_path = st.session_state["sentTr_model_path"]
    # Fetch the QA tokenizer once and share it between the bot and the model loader.
    qa_tokenizer = get_tokenizer(qa_model_path)
    if "flan-t5" in qa_model_path:
        qa_model = get_seq2seq_model(qa_model_path, _tokenizer=qa_tokenizer)
    else:
        qa_model = get_causal_model(qa_model_path)

    st.session_state["celeb_bot"] = CelebBot(
        celeb_name,
        celeb_gender,
        qa_tokenizer,
        qa_model,
        get_tokenizer(sent_tr_model_path),
        get_auto_model(sent_tr_model_path),
        *preprocess_text(celeb_name, knowledge, "en_core_web_lg"),
    )
    celeb_bot = st.session_state["celeb_bot"]

    # Replay the stored conversation so it survives the rerun.
    dialogue_container = st.container()
    with dialogue_container:
        for message in st.session_state["messages"]:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

    with st.sidebar:
        st.write("You can record your question...")
        st.session_state["prompt_from_audio"] = speech_to_text(
            start_prompt="Start Recording",
            stop_prompt="Stop Recording",
            language="en",
            use_container_width=True,
            just_once=True,
            key="STT",
        )
        st.text_input("Or text something...", key="text_input", on_change=text_submit)

        st.write("Example questions:")
        for example in _EXAMPLE_QUESTIONS:
            st.button(example, on_click=example_submit, args=[example])

    # Prefer a spoken prompt and fall back to the typed/example one. Using
    # `or` keeps `prompt` always bound and stops an empty or missing
    # transcript from hiding a pending text prompt.
    prompt = (
        st.session_state["prompt_from_audio"]
        or st.session_state["prompt_from_text"]
    )

    if prompt:
        celeb_bot.text = prompt

        # Display the user message in the chat container.
        with dialogue_container:
            st.chat_message("user").markdown(prompt)
        # Add the user message to the chat history.
        messages = st.session_state["messages"]
        messages.append({"role": "user", "content": prompt})

        # Generate the answer; once a previous exchange exists, pass the last
        # question/answer pair along as context.
        if len(messages) < 3:
            response = celeb_bot.question_answer()
        else:
            previous_question = messages[-3]["content"]
            previous_answer = messages[-2]["content"]
            chat_his = (
                f"Question: {previous_question}\n\nAnswer: {previous_answer}\n\n"
            )
            response = celeb_bot.question_answer(chat_his=chat_his)

        # Autoplay is disabled so playback can be driven from HTML.
        # NOTE(review): the returned value is not embedded in the markdown
        # below; confirm whether an audio element is missing from the template.
        b64 = celeb_bot.text_to_speech(autoplay=False)  # noqa: F841

        md = f"\n{response}\n"
        # Display the assistant response in the chat container.
        with dialogue_container:
            st.chat_message("assistant").markdown(
                md,
                unsafe_allow_html=True,
            )
        # Add the assistant response to the chat history.
        messages.append({"role": "assistant", "content": response})

        # The prompt has been consumed; clear both slots for the next rerun.
        st.session_state["prompt_from_audio"] = ""
        st.session_state["prompt_from_text"] = ""


if __name__ == "__main__":
    main()