# test-gpt-omni / app.py
# Gradio voice-chat demo (initial commit b6ab738).
import gradio as gr
import numpy as np
import io
import tempfile
from pydub import AudioSegment
from dataclasses import dataclass, field
import numpy as np
@dataclass
class AppState:
    """Per-session state threaded through every Gradio event handler."""

    # Accumulated raw microphone samples for the current user turn;
    # None until the first streamed chunk arrives.
    stream: np.ndarray | None = None
    # Sampling rate (Hz) of the incoming audio, taken from the first chunk.
    sampling_rate: int = 0
    # Set by process_audio once enough audio has accumulated to count as a pause.
    pause_detected: bool = False
    # True after the user presses "Stop Conversation"; prevents re-arming the mic.
    stopped: bool = False
    # NOTE(review): not read or written anywhere in this file — possibly vestigial.
    started_talking: bool = False
    conversation: list = field(default_factory=list)  # Use default_factory for mutable defaults
# Function to process audio input and detect pauses
def process_audio(audio: tuple, state: "AppState"):
    """Accumulate streamed microphone audio and stop recording on a detected pause.

    Parameters
    ----------
    audio : tuple | None
        ``(sampling_rate, samples)`` chunk from ``gr.Audio(type="numpy")``,
        or ``None`` when the stream event carries no new data.
    state : AppState
        Mutable per-session state; updated in place.

    Returns
    -------
    tuple
        ``(gr.Audio(recording=False), state)`` once a pause is detected
        (stops the client-side recording), otherwise ``(None, state)``.
    """
    # Gradio stream events can fire without a new chunk; ignore instead of
    # crashing on ``None[1]``.
    if audio is None:
        return None, state

    sampling_rate, chunk = audio
    if state.stream is None:
        # First chunk of the turn: remember the rate and start the buffer.
        state.stream = chunk
        state.sampling_rate = sampling_rate
    else:
        state.stream = np.concatenate((state.stream, chunk))

    # Custom pause detection logic (replace with actual implementation)
    # Placeholder heuristic: >1 second of accumulated audio counts as a pause.
    # NOTE(review): swap in real trailing-silence detection for production use.
    state.pause_detected = len(state.stream) > state.sampling_rate * 1
    if state.pause_detected:
        return gr.Audio(recording=False), state  # Stop recording
    return None, state
# Generate chatbot response from user audio input
def response(state: "AppState"):
    """Yield the chatbot's reply after the user's pause; a generator for Gradio streaming.

    Appends the user's audio and the (placeholder) assistant audio to
    ``state.conversation`` and yields ``(audio_update, state)``.
    """
    if not state.pause_detected:
        # BUG FIX: in a generator, ``return None, state`` stashes the tuple in
        # StopIteration.value and Gradio never sees it — yield the no-op
        # update instead, then finish.
        yield None, state
        return

    # Convert user audio to WAV format
    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        state.stream.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=state.stream.dtype.itemsize,
        channels=1 if len(state.stream.shape) == 1 else state.stream.shape[1]
    )
    segment.export(audio_buffer, format="wav")

    # Persist the user turn to a temp file so the chat UI can reference it by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        f.write(audio_buffer.getvalue())
    state.conversation.append({"role": "user", "content": {"path": f.name, "mime_type": "audio/wav"}})

    # Simulate chatbot's response (replace with mini omni model logic)
    chatbot_response = b"Simulated response audio content"  # Placeholder
    output_buffer = chatbot_response  # Stream actual chatbot response here
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
        f.write(output_buffer)
    state.conversation.append({"role": "assistant", "content": {"path": f.name, "mime_type": "audio/mp3"}})

    yield None, state
# --- Gradio Interface ---
def start_recording_user(state: AppState):
    """Re-arm the microphone for the next user turn, unless the conversation was stopped."""
    # After "Stop Conversation" this evaluates to None, which tells Gradio to
    # leave the input component untouched — same as the original fall-through.
    return gr.Audio(recording=True) if not state.stopped else None
# Build Gradio app using Blocks API
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Microphone input delivered as (sampling_rate, np.ndarray) tuples.
            input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
            output_audio = gr.Audio(label="Output Audio", streaming=True, autoplay=True)

    # Per-session state shared by every handler below.
    state = gr.State(value=AppState())

    # Feed microphone chunks to process_audio every 0.5 s, capped at 30 s per turn.
    stream = input_audio.stream(
        process_audio, [input_audio, state], [input_audio, state], stream_every=0.5, time_limit=30
    )

    # When recording stops, generate the assistant reply, then refresh the chat log.
    respond = input_audio.stop_recording(response, [state], [output_audio, state])
    respond.then(lambda s: s.conversation, [state], [chatbot])

    # Once the reply finishes playing, re-arm the microphone for the next turn.
    restart = output_audio.stop(start_recording_user, [state], [input_audio])

    # Reset state, disable recording, and cancel in-flight respond/restart events.
    cancel = gr.Button("Stop Conversation", variant="stop")
    cancel.click(lambda: (AppState(stopped=True), gr.Audio(recording=False)), None, [state, input_audio], cancels=[respond, restart])
# Launch the web app when this file is run as a script.
if __name__ == "__main__":
    demo.launch()