Spaces:
Runtime error
Runtime error
lhzstar
committed on
Commit
·
15303cb
1
Parent(s):
436ce71
new commits
Browse files
- app.py +79 -72
- celebbot.py +3 -3
- data.json +0 -0
- run_tts.py +1 -5
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from celebbot import CelebBot
|
2 |
import streamlit as st
|
|
|
3 |
from streamlit_mic_recorder import speech_to_text
|
4 |
from utils import *
|
5 |
|
@@ -7,7 +8,7 @@ from utils import *
|
|
7 |
def main():
|
8 |
|
9 |
hide_footer()
|
10 |
-
model_list = ["flan-t5-
|
11 |
celeb_data = get_celeb_data(f'data.json')
|
12 |
|
13 |
st.sidebar.header("CelebChat")
|
@@ -22,80 +23,86 @@ def main():
|
|
22 |
st.session_state["sentTr_model_path"] = "sentence-transformers/all-mpnet-base-v2"
|
23 |
if "start_chat" not in st.session_state:
|
24 |
st.session_state["start_chat"] = False
|
25 |
-
if "
|
26 |
-
st.session_state["
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
st.session_state["celeb_name"] = st.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
|
38 |
-
model_id=st.selectbox("Choose Your Flan-T5 model",options=model_list)
|
39 |
-
st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
|
40 |
-
|
41 |
-
st.form_submit_button(label="Start Chatting", on_click=start_chat, args=(st.session_state["celeb_name"], st.session_state["QA_model_path"]))
|
42 |
-
|
43 |
-
if st.session_state["start_chat"]:
|
44 |
-
celeb_gender = celeb_data[st.session_state["celeb_name"]]["gender"]
|
45 |
-
knowledge = celeb_data[st.session_state["celeb_name"]]["knowledge"]
|
46 |
-
st.session_state["celeb_bot"] = CelebBot(st.session_state["celeb_name"],
|
47 |
-
get_tokenizer(st.session_state["QA_model_path"]),
|
48 |
-
get_seq2seq_model(st.session_state["QA_model_path"]) if "flan-t5" in st.session_state["QA_model_path"] else get_causal_model(st.session_state["QA_model_path"]),
|
49 |
-
get_tokenizer(st.session_state["sentTr_model_path"]),
|
50 |
-
get_auto_model(st.session_state["sentTr_model_path"]),
|
51 |
-
*preprocess_text(st.session_state["celeb_name"], celeb_gender, knowledge, "en_core_web_sm")
|
52 |
-
)
|
53 |
-
|
54 |
-
dialogue_container = st.container()
|
55 |
-
with dialogue_container:
|
56 |
-
for message in st.session_state["messages"]:
|
57 |
-
with st.chat_message(message["role"]):
|
58 |
-
st.markdown(message["content"])
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
st.
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
)
|
97 |
-
# Display assistant response in chat message container
|
98 |
-
st.session_state["messages"].append({"role": "assistant", "content": response})
|
99 |
|
100 |
|
101 |
if __name__ == "__main__":
|
|
|
1 |
from celebbot import CelebBot
|
2 |
import streamlit as st
|
3 |
+
import time
|
4 |
from streamlit_mic_recorder import speech_to_text
|
5 |
from utils import *
|
6 |
|
|
|
8 |
def main():
|
9 |
|
10 |
hide_footer()
|
11 |
+
model_list = ["flan-t5-xl"]
|
12 |
celeb_data = get_celeb_data(f'data.json')
|
13 |
|
14 |
st.sidebar.header("CelebChat")
|
|
|
23 |
st.session_state["sentTr_model_path"] = "sentence-transformers/all-mpnet-base-v2"
|
24 |
if "start_chat" not in st.session_state:
|
25 |
st.session_state["start_chat"] = False
|
26 |
+
if "prompt_from_audio" not in st.session_state:
|
27 |
+
st.session_state["prompt_from_audio"] = ""
|
28 |
+
if "prompt_from_text" not in st.session_state:
|
29 |
+
st.session_state["prompt_from_text"] = ""
|
30 |
+
|
31 |
+
def text_submit():
|
32 |
+
st.session_state["prompt_from_text"] = st.session_state.widget
|
33 |
+
st.session_state.widget = ''
|
34 |
+
|
35 |
+
st.session_state["celeb_name"] = st.sidebar.selectbox('Choose a celebrity', options=list(celeb_data.keys()))
|
36 |
+
model_id=st.sidebar.selectbox("Choose Your Flan-T5 model",options=model_list)
|
37 |
+
st.session_state["QA_model_path"] = f"google/{model_id}" if "flan-t5" in model_id else model_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
celeb_gender = celeb_data[st.session_state["celeb_name"]]["gender"]
|
40 |
+
knowledge = celeb_data[st.session_state["celeb_name"]]["knowledge"]
|
41 |
+
st.session_state["celeb_bot"] = CelebBot(st.session_state["celeb_name"],
|
42 |
+
get_tokenizer(st.session_state["QA_model_path"]),
|
43 |
+
get_seq2seq_model(st.session_state["QA_model_path"]) if "flan-t5" in st.session_state["QA_model_path"] else get_causal_model(st.session_state["QA_model_path"]),
|
44 |
+
get_tokenizer(st.session_state["sentTr_model_path"]),
|
45 |
+
get_auto_model(st.session_state["sentTr_model_path"]),
|
46 |
+
*preprocess_text(st.session_state["celeb_name"], celeb_gender, knowledge, "en_core_web_sm")
|
47 |
+
)
|
48 |
|
49 |
+
dialogue_container = st.container()
|
50 |
+
with dialogue_container:
|
51 |
+
for message in st.session_state["messages"]:
|
52 |
+
with st.chat_message(message["role"]):
|
53 |
+
st.markdown(message["content"])
|
54 |
+
|
55 |
+
if "_last_audio_id" not in st.session_state:
|
56 |
+
st.session_state["_last_audio_id"] = 0
|
57 |
+
with st.sidebar:
|
58 |
+
st.session_state["prompt_from_audio"] = speech_to_text(start_prompt="Start Recording",stop_prompt="Stop Recording",language='en',use_container_width=True, just_once=True,key='STT')
|
59 |
+
st.text_input('Or write something', key='widget', on_change=text_submit)
|
60 |
+
|
61 |
+
if st.session_state["prompt_from_audio"] != None:
|
62 |
+
prompt = st.session_state["prompt_from_audio"]
|
63 |
+
elif st.session_state["prompt_from_text"] != None:
|
64 |
+
prompt = st.session_state["prompt_from_text"]
|
65 |
+
|
66 |
+
if prompt != None and prompt != '':
|
67 |
+
st.session_state["celeb_bot"].text = prompt
|
68 |
+
# Display user message in chat message container
|
69 |
+
with dialogue_container:
|
70 |
+
st.chat_message("user").markdown(prompt)
|
71 |
+
# Add user message to chat history
|
72 |
+
st.session_state["messages"].append({"role": "user", "content": prompt})
|
73 |
+
|
74 |
+
# Add assistant response to chat history
|
75 |
+
response = st.session_state["celeb_bot"].question_answer()
|
76 |
|
77 |
+
# disable autoplay to play in HTML
|
78 |
+
wav, sr = st.session_state["celeb_bot"].text_to_speech(autoplay=False)
|
79 |
+
md = f"""
|
80 |
+
<p>{response}</p>
|
81 |
+
"""
|
82 |
+
with dialogue_container:
|
83 |
+
st.chat_message("assistant").markdown(
|
84 |
+
md,
|
85 |
+
unsafe_allow_html=True,
|
86 |
+
)
|
87 |
+
|
88 |
+
# Play the audio (non-blocking)
|
89 |
+
import sounddevice as sd
|
90 |
+
try:
|
91 |
+
sd.stop()
|
92 |
+
sd.play(wav, sr)
|
93 |
+
time_span = len(wav)//sr + 1
|
94 |
+
time.sleep(time_span)
|
95 |
+
|
96 |
+
except sd.PortAudioError as e:
|
97 |
+
print("\nCaught exception: %s" % repr(e))
|
98 |
+
print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
|
99 |
+
except:
|
100 |
+
raise
|
101 |
+
# Display assistant response in chat message container
|
102 |
+
st.session_state["messages"].append({"role": "assistant", "content": response})
|
103 |
+
|
104 |
+
st.session_state["prompt_from_audio"] = ""
|
105 |
+
st.session_state["prompt_from_text"] = ""
|
|
|
|
|
|
|
106 |
|
107 |
|
108 |
if __name__ == "__main__":
|
celebbot.py
CHANGED
@@ -103,12 +103,12 @@ class CelebBot():
|
|
103 |
## have a conversation
|
104 |
else:
|
105 |
if re.search(re.compile(rf'\b(you|your|{self.name})\b', flags=re.IGNORECASE), self.text) != None:
|
106 |
-
instruction1 = f'
|
107 |
|
108 |
knowledge = self.retrieve_knowledge_assertions()
|
109 |
else:
|
110 |
-
instruction1 = f'
|
111 |
-
query = f"{instruction1}
|
112 |
input_ids = self.QA_tokenizer(f"{query}", return_tensors="pt").input_ids
|
113 |
outputs = self.QA_model.generate(input_ids, max_length=1024)
|
114 |
self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
103 |
## have a conversation
|
104 |
else:
|
105 |
if re.search(re.compile(rf'\b(you|your|{self.name})\b', flags=re.IGNORECASE), self.text) != None:
|
106 |
+
instruction1 = f'You are a celebrity named {self.name}. You need to answer the question based on knowledge and commonsense.'
|
107 |
|
108 |
knowledge = self.retrieve_knowledge_assertions()
|
109 |
else:
|
110 |
+
instruction1 = f'You need to answer the question based on commonsense.'
|
111 |
+
query = f"Context: {instruction1} {knowledge}\n\nQuestion: {self.text}\n\nAnswer:"
|
112 |
input_ids = self.QA_tokenizer(f"{query}", return_tensors="pt").input_ids
|
113 |
outputs = self.QA_model.generate(input_ids, max_length=1024)
|
114 |
self.text = self.QA_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
data.json
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
run_tts.py
CHANGED
@@ -109,11 +109,7 @@ def tts(text, embed_name, nlp, autoplay=True):
|
|
109 |
print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
|
110 |
except:
|
111 |
raise
|
112 |
-
|
113 |
-
byte_io = io.BytesIO(bytes_wav)
|
114 |
-
write(byte_io, synthesizer.sample_rate, wav.astype(np.float32))
|
115 |
-
result_bytes = byte_io.read()
|
116 |
-
return base64.b64encode(result_bytes).decode()
|
117 |
|
118 |
|
119 |
if __name__ == "__main__":
|
|
|
109 |
print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
|
110 |
except:
|
111 |
raise
|
112 |
+
return wav, synthesizer.sample_rate
|
|
|
|
|
|
|
|
|
113 |
|
114 |
|
115 |
if __name__ == "__main__":
|