"""A simple web interactive chat demo based on gradio."""

import base64
import io
import os

import gradio as gr
import numpy as np
import requests
from gradio_webrtc import AdditionalOutputs, ReplyOnPause, WebRTC
from pydub import AudioSegment

API_URL = os.getenv("API_URL", None)

client = None

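# With no API_URL configured, fall back to running inference locally.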
if API_URL is None:
    from inference import OmniInference

    omni_client = OmniInference('./checkpoint', 'cuda:0')
    omni_client.warm_up()

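# Output audio stream parameters.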
OUT_CHUNK = 4096
OUT_RATE = 24000
OUT_CHANNELS = 1

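# Optional Twilio credentials: when set, fetch TURN servers so WebRTC traffic
# can be relayed (useful behind restrictive NATs/firewalls).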
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

if account_sid and auth_token:
    from twilio.rest import Client

    client = Client(account_sid, auth_token)

    token = client.tokens.create()

    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None


def response(audio: tuple[int, np.ndarray], conversation: list[dict], img: str | None):
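    """Handle one user turn: encode the recorded audio (and optional image),
    send it to the inference API, and stream back audio chunks and chat updates.
    """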
    sampling_rate, audio_np = audio
    audio_np = audio_np.squeeze()

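    # Wrap the raw samples in a WAV container so they can be base64-encoded for the request.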
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        audio_np.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_np.dtype.itemsize,
        channels=1,
    )
    segment.export(audio_buffer, format="wav")

    # Show the user's audio in the chat and add a placeholder assistant turn.
    conversation.append({"role": "user", "content": gr.Audio((sampling_rate, audio_np))})
    conversation.append({"role": "assistant", "content": ""})

    base64_encoded = base64.b64encode(audio_buffer.getvalue()).decode("utf-8")

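    # The server replies with a multipart stream of audio/wav and text/plain
    # frames separated by '--frame' boundaries; parse and forward them as they arrive.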
    if API_URL is not None:
        output_audio_bytes = b""
        files = {"audio": base64_encoded}
        if img is not None:
            with open(img, "rb") as f:
                files["image"] = base64.b64encode(f.read()).decode("utf-8")
        print("sending request to server")
        resp_text = ""
        with requests.post(API_URL, json=files, stream=True) as resp:
            try:
                buffer = b''
                for chunk in resp.iter_content(chunk_size=2048):
                    buffer += chunk
                    while b'\r\n--frame\r\n' in buffer:
                        frame, buffer = buffer.split(b'\r\n--frame\r\n', 1)
                        if b'Content-Type: audio/wav' in frame:
                            audio_data = frame.split(b'\r\n\r\n', 1)[1]
                            output_audio_bytes += audio_data
                            audio_array = np.frombuffer(audio_data, dtype=np.int16).reshape(1, -1)
                            yield (OUT_RATE, audio_array, "mono")
                        elif b'Content-Type: text/plain' in frame:
                            text_data = frame.split(b'\r\n\r\n', 1)[1].decode()
                            resp_text += text_data
                            if len(text_data) > 0:
                                conversation[-1]["content"] = resp_text
                                yield AdditionalOutputs(conversation)
            except Exception as e:
                raise Exception(f"Error during audio streaming: {e}") from e


def main(port=None):
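    """Build the Gradio demo and launch it, optionally on a fixed port."""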
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style='text-align: center'>
            Mini-Omni-2 Chat (Powered by WebRTC ⚡️)
            </h1>
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        audio = WebRTC(
                            label="Stream",
                            rtc_configuration=rtc_configuration,
                            mode="send-receive",
                            modality="audio",
                        )
                    with gr.Column():
                        img = gr.Image(label="Image", type="filepath")
            with gr.Column():
                conversation = gr.Chatbot(label="Conversation", type="messages")

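        # ReplyOnPause triggers `response` once the user stops speaking and streams
        # the yielded audio back over the same WebRTC connection.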
        audio.stream(
            fn=ReplyOnPause(
                response, output_sample_rate=OUT_RATE, output_frame_size=480
            ),
            inputs=[audio, conversation, img],
            outputs=[audio],
            time_limit=90,
        )
        audio.on_additional_outputs(lambda c: c, outputs=[conversation])

    if port is not None:
        demo.queue().launch(share=False, server_name="0.0.0.0", server_port=port)
    else:
        demo.queue().launch()


if __name__ == "__main__":
    import fire

    fire.Fire(main)