import base64
import copy

import numpy as np
import streamlit as st

from src.generation import MAX_AUDIO_LENGTH
from src.utils import bytes_to_array, array_to_bytes
from src.content.common import (
    DEFAULT_DIALOGUE_STATES,
    init_state_section,
    header_section,
    sidebar_fragment,
    retrive_response_with_ui
)


# TODO: change this.
DEFAULT_PROMPT = "Please follow the instruction in the speech."


def _update_audio(audio_bytes):
    """Decode the uploaded/recorded audio and cache it in session state:
    the full waveform for playback, and a base64 copy truncated to
    MAX_AUDIO_LENGTH seconds (at 16 kHz) for the model request."""
    origin_audio_array = bytes_to_array(audio_bytes)
    truncated_audio_array = origin_audio_array[: MAX_AUDIO_LENGTH * 16000]
    truncated_audio_bytes = array_to_bytes(truncated_audio_array)
    st.session_state.vc_audio_array = origin_audio_array
    st.session_state.vc_audio_base64 = base64.b64encode(truncated_audio_bytes).decode("utf-8")


@st.dialog("Specify Audio")
def audio_attach_dialogue():
    st.markdown("**Upload**")

    uploaded_file = st.file_uploader(
        label="**Upload Audio:**",
        label_visibility="collapsed",
        type=["wav", "mp3"],
        on_change=lambda: st.session_state.update(
            on_upload=True,
            vc_messages=[],
            disprompt=True
        ),
        key="upload"
    )

    if uploaded_file and st.session_state.on_upload:
        audio_bytes = uploaded_file.read()
        _update_audio(audio_bytes)
        st.session_state.update(
            on_upload=False,
            new_prompt=DEFAULT_PROMPT
        )
        st.rerun()


def bottom_input_section():
    st.info(":bulb: Ask something with clear intention.")

    bottom_cols = st.columns([0.03, 0.03, 0.94])

    with bottom_cols[0]:
        st.button(
            "Clear",
            disabled=st.session_state.disprompt,
            on_click=lambda: st.session_state.update(copy.deepcopy(DEFAULT_DIALOGUE_STATES))
        )

    with bottom_cols[1]:
        # Double-escape the "+" so it survives both the Python string literal
        # and Streamlit's markdown rendering of the button label.
        if st.button("\\+ Audio", disabled=st.session_state.disprompt):
            audio_attach_dialogue()

    with bottom_cols[2]:
        uploaded_file = st.audio_input(
            label="record audio",
            label_visibility="collapsed",
            on_change=lambda: st.session_state.update(
                on_record=True,
                vc_messages=[],
                disprompt=True
            ),
            key="record"
        )

        if uploaded_file and st.session_state.on_record:
            audio_bytes = uploaded_file.read()
            _update_audio(audio_bytes)
            st.session_state.update(
                on_record=False,
                new_prompt=DEFAULT_PROMPT
            )


def conversation_section():
    # Replay the chat history stored in session state.
    for message in st.session_state.vc_messages:
        with st.chat_message(message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            if message.get("audio", np.array([])).shape[0]:
                st.audio(message["audio"], format="audio/wav", sample_rate=16000)
            if message.get("content"):
                st.write(message["content"])

    # st._bottom is Streamlit's private bottom-pinned container
    # (the same slot st.chat_input renders into).
    with st._bottom:
        bottom_input_section()

    if one_time_prompt := st.session_state.new_prompt:
        one_time_array = st.session_state.vc_audio_array
        one_time_base64 = st.session_state.vc_audio_base64

        # Consume the one-shot prompt and cached audio so the next rerun
        # does not resend them.
        st.session_state.update(
            new_prompt="",
            vc_audio_array=np.array([]),
            vc_audio_base64="",
            vc_messages=[]
        )

        with st.chat_message("user"):
            st.audio(one_time_array, format="audio/wav", sample_rate=16000)
        st.session_state.vc_messages.append({"role": "user", "audio": one_time_array})

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                error_msg, warnings, response = retrive_response_with_ui(
                    one_time_prompt,
                    one_time_array,
                    one_time_base64,
                    stream=True
                )

            st.session_state.vc_messages.append({
                "role": "assistant",
                "error": error_msg,
                "warnings": warnings,
                "content": response
            })

        st.session_state.disprompt = False
        st.rerun(scope="app")


def voice_chat_page():
    init_state_section()
    header_section(component_name="Voice Chat")

    with st.sidebar:
        sidebar_fragment()

    conversation_section()
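

# --- Usage sketch (an assumption, not part of the original module) ---
# If this file is run directly as a single-page app
# (`streamlit run voice_chat.py`), the page function must be invoked at
# import time; Streamlit executes the script with __name__ == "__main__",
# so this guard fires. In a multipage setup where a host app registers
# the page (e.g. via st.Page(voice_chat_page)), the guard is a no-op and
# the host calls voice_chat_page() instead.
if __name__ == "__main__":
    voice_chat_page()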