import os
import base64
import numpy as np
from openai import APIConnectionError
import streamlit as st
import streamlit.components.v1 as components
from streamlit_mic_recorder import mic_recorder
from utils import load_model, generate_response, bytes_to_array, start_server, NoAudioException
# Suggested one-click instructions used when the audio comes from an upload or
# a microphone recording; curated example clips carry their own instruction list.
general_instructions = [
    "Please transcribe this speech.",
    "Please summarise this speech.",
]
# Built-in example clips (read from ``audio_samples/<name>.wav``) mapped to the
# instructions suggested for each clip. Insertion order is the order in which
# the clips appear in the "Select Audio From Examples" dropdown.
_AUDIO_SAMPLES_W_INSTRUCT = {
    '1_ASR_IMDA_PART1_ASR_v2_141': ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
    '7_ASR_IMDA_PART3_30_ASR_v2_2269': ["Need this talk written down, please."],
    '17_ASR_IMDA_PART6_30_ASR_v2_1413': ["Record the spoken word in text form."],
    '25_ST_COVOST2_ZH-CN_EN_ST_V2_4567': ["Please translate the given speech to English."],
    '26_ST_COVOST2_EN_ZH-CN_ST_V2_5422': ["Please translate the given speech to Chinese."],
    '30_SI_ALPACA-GPT4-AUDIO_SI_V2_1454': ["Please follow the instruction in the speech."],
    '32_SQA_CN_COLLEDGE_ENTRANCE_ENGLISH_TEST_SQA_V2_572': ["What does the man think the woman should do at 4:00."],
    '33_SQA_IMDA_PART3_30_SQA_V2_2310': ["Does Speaker2's wife cook for Speaker2 when they are at home."],
    '34_SQA_IMDA_PART3_30_SQA_V2_3621': ["Does the phrase \"#gai-gai#\" have a meaning in Chinese or Hokkien language."],
    '35_SQA_IMDA_PART3_30_SQA_V2_4062': ["What is the color of the vase mentioned in the dialogue."],
    '36_DS_IMDA_PART4_30_DS_V2_849': ["Condense the dialogue into a concise summary highlighting major topics and conclusions."],
    '39_Paralingual_IEMOCAP_ER_V2_91': ["Based on the speaker's speech patterns, what do you think they are feeling."],
    '40_Paralingual_IEMOCAP_ER_V2_567': ["Based on the speaker's speech patterns, what do you think they are feeling."],
    '42_Paralingual_IEMOCAP_GR_V2_320': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
    '43_Paralingual_IEMOCAP_GR_V2_129': ["Is it possible for you to identify whether the speaker in this recording is male or female."],
    '45_Paralingual_IMDA_PART3_30_GR_V2_12312': ["So, who's speaking in the second part of the clip?", "So, who's speaking in the first part of the clip?"],
    '47_Paralingual_IMDA_PART3_30_NR_V2_10479': ["Can you guess which ethnic group this person is from based on their accent."],
    '49_Paralingual_MELD_ER_V2_676': ["What emotions do you think the speaker is expressing."],
    '50_Paralingual_MELD_ER_V2_692': ["Based on the speaker's speech patterns, what do you think they are feeling."],
    '51_Paralingual_VOXCELEB1_GR_V2_2148': ["May I know the gender of the speaker."],
    '53_Paralingual_VOXCELEB1_NR_V2_2286': ["What's the nationality identity of the speaker."],
    '55_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_2': ["What impact would the growth of the healthcare sector have on the country's economy in terms of employment and growth."],
    '56_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_415': ["Based on the statement, can you summarize the speaker's position on the recent controversial issues in Singapore."],
    '57_SQA_PUBLIC_SPEECH_SG_TEST_SQA_V2_460': ["How does the author respond to parents' worries about masks in schools."],
    '2_ASR_IMDA_PART1_ASR_v2_2258': ["Turn the spoken language into a text format.", "Please translate the content into Chinese."],
    '3_ASR_IMDA_PART1_ASR_v2_2265': ["Turn the spoken language into a text format."],
    '4_ASR_IMDA_PART2_ASR_v2_999': ["Translate the spoken words into text format."],
    '5_ASR_IMDA_PART2_ASR_v2_2241': ["Translate the spoken words into text format."],
    '6_ASR_IMDA_PART2_ASR_v2_3409': ["Translate the spoken words into text format."],
    '8_ASR_IMDA_PART3_30_ASR_v2_1698': ["Need this talk written down, please."],
    '9_ASR_IMDA_PART3_30_ASR_v2_2474': ["Need this talk written down, please."],
    '11_ASR_IMDA_PART4_30_ASR_v2_3771': ["Write out the dialogue as text."],
    '12_ASR_IMDA_PART4_30_ASR_v2_103': ["Write out the dialogue as text."],
    '10_ASR_IMDA_PART4_30_ASR_v2_1527': ["Write out the dialogue as text."],
    '13_ASR_IMDA_PART5_30_ASR_v2_1446': ["Translate this vocal recording into a textual format."],
    '14_ASR_IMDA_PART5_30_ASR_v2_2281': ["Translate this vocal recording into a textual format."],
    '15_ASR_IMDA_PART5_30_ASR_v2_4388': ["Translate this vocal recording into a textual format."],
    '16_ASR_IMDA_PART6_30_ASR_v2_576': ["Record the spoken word in text form."],
    '18_ASR_IMDA_PART6_30_ASR_v2_2834': ["Record the spoken word in text form."],
    '19_ASR_AIShell_zh_ASR_v2_5044': ["Transform the oral presentation into a text document."],
    '20_ASR_LIBRISPEECH_CLEAN_ASR_V2_833': ["Please provide a written transcription of the speech."],
    '27_ST_COVOST2_EN_ZH-CN_ST_V2_6697': ["Please translate the given speech to Chinese."],
    '28_SI_ALPACA-GPT4-AUDIO_SI_V2_299': ["Please follow the instruction in the speech."],
    '29_SI_ALPACA-GPT4-AUDIO_SI_V2_750': ["Please follow the instruction in the speech."],
}

# Sample rate handed to ``st.audio`` and used to turn a sample count into seconds.
# NOTE(review): assumes ``utils.bytes_to_array`` yields a 1-D array at 16 kHz — confirm.
_SAMPLE_RATE = 16000
# Clips longer than this (in seconds) trigger the truncation warning.
_MAX_AUDIO_SECONDS = 30.0


def _set_audio(audio_bytes, instructions):
    """Make ``audio_bytes`` the current audio and offer ``instructions`` as one-click prompts.

    ``audio_base64`` and ``audio_array`` are always written together so they
    describe the same clip. The instruction list is copied so later session
    edits can never mutate the shared module-level lists.
    """
    st.session_state.default_instruction = list(instructions)
    st.session_state.audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
    st.session_state.audio_array = bytes_to_array(audio_bytes)


def _render_sidebar():
    """Draw the sampling sliders and the 'Clear History' button in the sidebar."""
    with st.sidebar:
        # NOTE(review): this markup is empty and renders nothing; presumably custom
        # styling used to live here — confirm before removing.
        st.markdown("\n", unsafe_allow_html=True)
        st.slider(label='Temperature', min_value=0.0, max_value=2.0, value=0.7, key='temperature')
        st.slider(label='Top P', min_value=0.0, max_value=1.0, value=1.0, key='top_p')
        if st.sidebar.button('Clear History'):
            # Reset the encoded audio together with the decoded array; otherwise a
            # stale ``audio_base64`` survives while the UI shows no audio.
            st.session_state.update(messages=[],
                                    on_upload=False,
                                    on_record=False,
                                    on_select=False,
                                    audio_base64='',
                                    audio_array=np.array([]))


def _init_session_state():
    """Start the model server/client once per session and seed default state keys."""
    if "server" not in st.session_state:
        st.session_state.server = start_server()
    if "client" not in st.session_state or 'model_name' not in st.session_state:
        st.session_state.client, st.session_state.model_name = load_model()

    # Built fresh on every call so mutable defaults are never shared between sessions.
    defaults = {
        "audio_base64": '',
        "audio_array": np.array([]),
        "default_instruction": [],
        # One-shot flags set by the input widgets' callbacks; read in _render_audio_inputs.
        "on_select": False,
        "on_upload": False,
        "on_record": False,
        "prompt": "",
        "disprompt": False,
        "messages": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _render_header():
    """Render the page title and the short description of the demo."""
    # The original literal here was unterminated (a syntax error); presumably its
    # HTML heading markup was lost. Keep the visible title text.
    st.markdown("MERaLiON-AudioLLM ChatBot 🤖\n", unsafe_allow_html=True)
    st.markdown(
        "This demo is based on [MERaLiON-AudioLLM](https://huggingface.co/MERaLiON/MERaLiON-AudioLLM-Whisper-SEA-LION),\n"
        "developed by I2R, A*STAR, in collaboration with AISG, Singapore.\n"
        "It is tailored for Singapore’s multilingual and multicultural landscape."
    )


def _render_audio_inputs():
    """Draw the three audio sources (example picker, uploader, recorder).

    A source's audio is stored only when its widget callback has just raised the
    matching ``on_*`` flag, so each new input is processed exactly once.
    """
    col1, col2, col3 = st.columns([4, 4, 1.2])

    with col1:
        st.markdown("**Select Audio From Examples:**")
        sample_name = st.selectbox(
            label="**Select Audio:**",
            label_visibility="collapsed",
            options=list(_AUDIO_SAMPLES_W_INSTRUCT),
            index=None,
            placeholder="Select an audio sample:",
            on_change=lambda: st.session_state.update(on_select=True, messages=[]),
            key='select')
        if sample_name and st.session_state.on_select:
            # Context manager closes the handle (the original leaked it).
            with open(f"audio_samples/{sample_name}.wav", "rb") as sample_file:
                audio_bytes = sample_file.read()
            _set_audio(audio_bytes, _AUDIO_SAMPLES_W_INSTRUCT[sample_name])

    with col2:
        st.markdown("or **Upload Audio:**")
        uploaded_file = st.file_uploader(
            label="**Upload Audio:**",
            label_visibility="collapsed",
            type=['wav', 'mp3'],
            on_change=lambda: st.session_state.update(on_upload=True, messages=[]),
            key='upload'
        )
        if uploaded_file and st.session_state.on_upload:
            # getvalue() returns the whole buffer regardless of the stream position.
            _set_audio(uploaded_file.getvalue(), general_instructions)

    with col3:
        st.markdown("or **Record Audio:**")
        recording = mic_recorder(
            format="wav",
            use_container_width=True,
            callback=lambda: st.session_state.update(on_record=True, messages=[]),
            key='record')
        if recording and st.session_state.on_record:
            _set_audio(recording["bytes"], general_instructions)

    # NOTE(review): empty markup, renders nothing; presumably styling was stripped — confirm.
    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )


def _render_audio_and_instructions():
    """Show the current audio with a length warning plus the example-instruction buttons."""
    audio_array = st.session_state.audio_array
    if not audio_array.size:
        return

    with st.chat_message("user"):
        st.audio(audio_array, format="audio/wav", sample_rate=_SAMPLE_RATE)
        if audio_array.shape[0] / _SAMPLE_RATE > _MAX_AUDIO_SECONDS:
            st.warning("MERaLiON-AudioLLM can only process audio for up to 30 seconds. Audio longer than that will be truncated.")

    for i, inst in enumerate(st.session_state.default_instruction):
        # Clicking a button submits its instruction as the prompt and locks input
        # until the response has been produced.
        st.button(
            f"**Example Instruction {i+1}**: {inst}",
            args=(inst,),
            disabled=st.session_state.disprompt,
            on_click=lambda p: st.session_state.update(disprompt=True, prompt=p)
        )


def _render_history():
    """Replay only the latest exchange (the last two messages) from session state."""
    for message in st.session_state.messages[-2:]:
        with st.chat_message(message["role"]):
            if message.get("error"):
                st.error(message["error"])
            for warning_msg in message.get("warnings", []):
                st.warning(warning_msg)
            if message.get("content"):
                st.write(message["content"])


def _handle_prompt():
    """Read a prompt (typed or from an instruction button), stream the reply, then rerun.

    Failures from ``generate_response`` are turned into a user-facing error
    message stored with the assistant turn rather than propagated.
    """
    if prompt := st.chat_input(
        placeholder="Type Your Instruction Here",
        disabled=st.session_state.disprompt,
        on_submit=lambda: st.session_state.update(disprompt=True)
    ):
        st.session_state.prompt = prompt

    if not st.session_state.prompt:
        return

    with st.chat_message("user"):
        st.write(st.session_state.prompt)
    st.session_state.messages.append({"role": "user", "content": st.session_state.prompt})

    with st.chat_message("assistant"):
        response, error_msg, warning_msgs = "", "", []
        with st.spinner("Thinking..."):
            try:
                stream, warning_msgs = generate_response(st.session_state.prompt)
                for warning_msg in warning_msgs:
                    st.warning(warning_msg)
                response = st.write_stream(stream)
            except NoAudioException:
                error_msg = "Please specify audio first!"
            except APIConnectionError:
                error_msg = "Internet connection seems to be down. Please contact the administrator to restart the space."
            except Exception as e:  # UI boundary: surface any failure to the user
                error_msg = f"Caught Exception: {repr(e)}. Please contact the administrator."

    st.session_state.messages.append({
        "role": "assistant",
        "error": error_msg,
        "warnings": warning_msgs,
        "content": response
    })
    # Unlock the inputs and clear the pending prompt before redrawing the page.
    st.session_state.update(disprompt=False, prompt="")
    st.rerun()


def audio_llm():
    """Render the MERaLiON-AudioLLM chat page.

    Draws the sidebar and header and sets up session state. Then it collects
    audio from an example, an upload or the microphone, shows that audio with
    suggested instructions and the latest exchange, and answers a new prompt
    via ``generate_response``.
    """
    _render_sidebar()
    _init_session_state()
    _render_header()
    _render_audio_inputs()
    # The input flags are one-shot: clear them unconditionally once the inputs have
    # been handled. The original only cleared them when the decoded audio was
    # non-empty, which could leave a flag stuck and re-process input on every rerun.
    st.session_state.update(on_upload=False, on_record=False, on_select=False)
    _render_audio_and_instructions()
    _render_history()
    _handle_prompt()