# webapp.py
"""FastAPI app bridging a browser WebSocket to a per-client AudioLoop.

Routes:
    GET /        -> index.html located next to this file
    /web_ui/...  -> static files from this file's directory
    WS  /ws      -> bidirectional JSON message bridge (audio, text, tool calls)
"""
import asyncio
import base64
import json
import os

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
import uvicorn

from handler import AudioLoop  # Import your AudioLoop from above

app = FastAPI()

# Serve the directory that contains this file under /web_ui.
# NOTE(review): this exposes every file in that directory (including this
# module's source) over HTTP; consider pointing the mount at a dedicated
# assets sub-directory.
current_dir = os.path.dirname(os.path.realpath(__file__))
app.mount("/web_ui", StaticFiles(directory=current_dir), name="web_ui")


@app.get("/")
async def get_index():
    """Read index.html from this file's directory and return it as HTML."""
    index_path = os.path.join(current_dir, "index.html")
    with open(index_path, "r", encoding="utf-8") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content)


def _audio_message(payload):
    """Wrap a base64-encoded PCM payload in a ``realtime_input`` envelope.

    The payload is decoded and re-encoded, so a payload that is not valid
    base64 raises here instead of being forwarded as-is.
    """
    raw_pcm = base64.b64decode(payload)
    return {
        "realtime_input": {
            "media_chunks": [
                {
                    "data": base64.b64encode(raw_pcm).decode(),
                    "mime_type": "audio/pcm",
                }
            ]
        }
    }


def _text_message(user_text):
    """Wrap user text in a single-turn, turn-complete ``client_content`` envelope."""
    return {
        "client_content": {
            "turn_complete": True,
            "turns": [
                {
                    "role": "user",
                    "parts": [
                        {"text": user_text}
                    ]
                }
            ]
        }
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one WebSocket client to its own AudioLoop instance.

    Accepts the connection, starts ``AudioLoop.run()`` as a background task,
    and then runs two forwarders concurrently: client -> ``out_queue`` and
    ``audio_in_queue`` -> client. As soon as either forwarder finishes
    (disconnect or error), the other forwarder and the AudioLoop task are
    cancelled and awaited so nothing outlives the connection.
    """
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")

    # Create and start a new AudioLoop instance for this client
    audio_loop = AudioLoop()
    loop_task = asyncio.create_task(audio_loop.run())
    print("[websocket_endpoint] Started new AudioLoop for client")

    async def from_client_to_gemini():
        """Handles incoming messages from the client and forwards them to Gemini."""
        # Re-ordering state for sequenced audio chunks: chunks are buffered by
        # sequence number and released strictly in order, starting at 0.
        # NOTE(review): a sequence number that never arrives stalls audio
        # forwarding and lets this buffer grow; consider bounding it.
        pending_audio = {}
        expected_audio_seq = 0
        try:
            while True:
                data = await websocket.receive_text()
                try:
                    msg = json.loads(data)
                except json.JSONDecodeError as e:
                    # One malformed frame should not end the whole session.
                    print("[from_client_to_gemini] Ignoring malformed JSON:", e)
                    continue
                msg_type = msg.get("type")

                # Handle audio data from client
                if msg_type == "audio":
                    forward_msg = _audio_message(msg["payload"])
                    seq = msg.get("seq")
                    if seq is None:
                        # If no sequence number is provided, forward immediately
                        await audio_loop.out_queue.put(forward_msg)
                    elif isinstance(seq, int) and seq < expected_audio_seq:
                        # Already past this sequence number: it could never be
                        # drained from the buffer, so storing it would only leak.
                        print("[from_client_to_gemini] Dropping stale audio chunk, seq:", seq)
                    else:
                        pending_audio[seq] = forward_msg
                        # Forward any messages that are now in order
                        while expected_audio_seq in pending_audio:
                            await audio_loop.out_queue.put(
                                pending_audio.pop(expected_audio_seq)
                            )
                            expected_audio_seq += 1

                # Handle text data from client
                elif msg_type == "text":
                    user_text = msg.get("content", "")
                    print("[from_client_to_gemini] Forwarding user text to Gemini:", user_text)
                    await audio_loop.out_queue.put(_text_message(user_text))

                elif msg_type == "tool_call_response":
                    # Handle tool call response from client
                    await audio_loop.handle_tool_call(msg["payload"])

                else:
                    print("[from_client_to_gemini] Unknown message type:", msg_type)

        except WebSocketDisconnect:
            # The AudioLoop task is cancelled by the endpoint's cleanup below.
            print("[from_client_to_gemini] Client disconnected.")
        except Exception as e:
            print("[from_client_to_gemini] Error:", e)

    async def from_gemini_to_client():
        """Reads messages from Gemini and sends them back to the client."""
        try:
            while True:
                message = await audio_loop.audio_in_queue.get()
                message_type = message["type"]

                if message_type == "audio":
                    # Audio data is already base64 encoded from handler.py
                    await websocket.send_text(json.dumps(message))
                    print("[from_gemini_to_client] Sending audio chunk to client")
                elif message_type == "function_call":
                    # Forward function call to client
                    await websocket.send_text(json.dumps(message))
                    print("[from_gemini_to_client] Forwarding function call to client")
                # Any other message type is not forwarded.

        except WebSocketDisconnect:
            print("[from_gemini_to_client] Client disconnected.")
            audio_loop.stop()
        except Exception as e:
            print("[from_gemini_to_client] Error:", e)

    # Launch both forwarders concurrently. If either fails or disconnects, we
    # exit. asyncio.wait(FIRST_COMPLETED) is used instead of gather(): gather()
    # would keep waiting on the surviving forwarder, which can block forever on
    # its queue/socket once the peer is gone.
    client_task = asyncio.create_task(from_client_to_gemini())
    gemini_task = asyncio.create_task(from_gemini_to_client())
    try:
        await asyncio.wait(
            {client_task, gemini_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        print("[websocket_endpoint] WebSocket handler finished.")
        # Cancel whatever is still running (cancelling a finished task is a
        # no-op) and wait for all of it to unwind before returning.
        tasks = (client_task, gemini_task, loop_task)
        for task in tasks:
            task.cancel()
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        # Cancellation is expected here; any other AudioLoop failure is reported
        # rather than allowed to escape from the cleanup path.
        loop_outcome = outcomes[-1]
        if isinstance(loop_outcome, Exception):
            print("[websocket_endpoint] AudioLoop ended with error:", loop_outcome)
        print("[websocket_endpoint] Cleaned up AudioLoop for client")


if __name__ == "__main__":
    uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)