File size: 4,844 Bytes
6bc94ac
 
436ce71
6bc94ac
 
 
 
436ce71
6bc94ac
436ce71
 
 
 
 
 
 
6bc94ac
 
 
 
 
 
 
 
436ce71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bc94ac
436ce71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bc94ac
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from celebbot import CelebBot
import streamlit as st
from streamlit_mic_recorder import speech_to_text
from utils import *


def main():
    """Render the CelebChat Streamlit page and handle one script run.

    The sidebar holds an "about" expander, a form for picking the celebrity
    and the QA model, a speech-to-text recorder and a text box. Once the form
    has been submitted, a ``CelebBot`` is built for the selection, the stored
    conversation is replayed, and any new prompt (spoken or typed) is answered
    with text plus an embedded audio clip.

    All state that must survive Streamlit reruns lives in ``st.session_state``.
    """
    # Local import: only this function needs it, and it keeps the block
    # self-contained.
    import html

    hide_footer()
    model_list = ["flan-t5-large", "flan-t5-xl", "Falcon-7b-instruct"]
    celeb_data = get_celeb_data('data.json')

    st.sidebar.header("CelebChat")
    expander = st.sidebar.expander('About the app')
    with expander:
        st.markdown("This app is a demo of celebrity chatting!")

    # Seed session state on the first run only; later reruns keep their values.
    session_defaults = {
        "messages": [],
        "QA_model_path": "google/flan-t5-base",
        "sentTr_model_path": "sentence-transformers/all-mpnet-base-v2",
        "start_chat": False,
        "prompt": None,
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    def start_chat(name, model_id):
        """Form-submit callback: enable the chat only for a usable selection.

        Truthiness (rather than ``!= ''``) also rejects ``None``, which is
        what ``st.selectbox`` yields when it has no options to offer.
        """
        print(name, model_id)
        st.session_state["start_chat"] = bool(name and model_id)

    with st.sidebar.form("my_form"):
        print("enter form")
        st.session_state["celeb_name"] = st.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
        model_id = st.selectbox("Choose Your Flan-T5 model", options=model_list)
        # Flan-T5 checkpoints are addressed under the "google/" namespace;
        # any other model id is used as given.
        st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id

        st.form_submit_button(
            label="Start Chatting",
            on_click=start_chat,
            args=(st.session_state["celeb_name"], st.session_state["QA_model_path"]),
        )

    if st.session_state["start_chat"]:
        celeb_name = st.session_state["celeb_name"]
        qa_model_path = st.session_state["QA_model_path"]
        sent_tr_model_path = st.session_state["sentTr_model_path"]
        celeb_gender = celeb_data[celeb_name]["gender"]
        knowledge = celeb_data[celeb_name]["knowledge"]

        # Seq2seq loader for Flan-T5 checkpoints, causal-LM loader otherwise.
        if "flan-t5" in qa_model_path:
            qa_model = get_seq2seq_model(qa_model_path)
        else:
            qa_model = get_causal_model(qa_model_path)

        # The bot is rebuilt on every rerun.
        # NOTE(review): presumably the get_* helpers cache their results;
        # confirm in utils, otherwise this reloads the models each run.
        st.session_state["celeb_bot"] = CelebBot(
            celeb_name,
            get_tokenizer(qa_model_path),
            qa_model,
            get_tokenizer(sent_tr_model_path),
            get_auto_model(sent_tr_model_path),
            *preprocess_text(celeb_name, celeb_gender, knowledge, "en_core_web_sm")
        )

        # Replay the stored conversation so it survives the rerun.
        dialogue_container = st.container()
        with dialogue_container:
            for message in st.session_state["messages"]:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        if "_last_audio_id" not in st.session_state:
            st.session_state["_last_audio_id"] = 0
        with st.sidebar:
            prompt_from_audio = speech_to_text(start_prompt="Start Recording", stop_prompt="Stop Recording", language='en', use_container_width=True, just_once=True, key='STT')
            prompt_from_text = st.text_input('Or write something')

        # A fresh recording takes precedence over the text box.
        # NOTE(review): the text box keeps its content between reruns, so an
        # unrelated rerun appears to re-submit the same typed prompt; confirm
        # whether that is intended.
        if prompt_from_audio is not None:
            st.session_state["prompt"] = prompt_from_audio
        elif prompt_from_text is not None:
            st.session_state["prompt"] = prompt_from_text
        print(st.session_state["prompt"])

        if st.session_state["prompt"]:
            prompt = st.session_state["prompt"]
            celeb_bot = st.session_state["celeb_bot"]
            celeb_bot.text = prompt

            # Display user message in chat message container
            with dialogue_container:
                st.chat_message("user").markdown(prompt)
            # Add user message to chat history
            st.session_state["messages"].append({"role": "user", "content": prompt})

            response = celeb_bot.question_answer()

            # disable autoplay to play in HTML
            b64 = celeb_bot.text_to_speech(autoplay=False)

            # The block below is rendered as raw HTML, so the generated text
            # must be escaped; otherwise markup in the answer would be
            # injected into the page.
            safe_response = html.escape(str(response))
            md = f"""
            <p>{safe_response}</p>
            <audio controls autoplay style="display:none;">
            <source src="data:audio/wav;base64,{b64}" type="audio/wav">
            Your browser does not support the audio element.
            </audio>
            """
            # Display assistant response in chat message container
            with dialogue_container:
                st.chat_message("assistant").markdown(
                    md,
                    unsafe_allow_html=True,
                )
            # Add assistant response to chat history (unescaped: the replay
            # above renders it without unsafe_allow_html).
            st.session_state["messages"].append({"role": "assistant", "content": response})

# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()