Spaces:
Running
Running
# handler.py | |
import asyncio | |
import base64 | |
import json | |
import os | |
import traceback | |
from websockets.asyncio.client import connect | |
import os | |
# Local runs keep the key in a .env file; when GOOGLE_API_KEY is already set
# in the environment, python-dotenv is never imported.
if not os.environ.get('GOOGLE_API_KEY'):
    from dotenv import load_dotenv

    load_dotenv()

# Connection settings for the Gemini bidirectional websocket endpoint.
host = "generativelanguage.googleapis.com"
model = "gemini-2.0-flash-exp"
# Raises KeyError at import time if the key is still missing after load_dotenv().
api_key = os.environ["GOOGLE_API_KEY"]
uri = (
    f"wss://{host}/ws/"
    "google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent"
    f"?key={api_key}"
)
class AudioLoop:
    """Relay messages between a local client and the Gemini websocket.

    Messages placed on ``out_queue`` are forwarded to Gemini by
    :meth:`send_realtime`. Audio chunks and tool calls received from Gemini
    are placed on ``audio_in_queue`` by :meth:`receive_audio` as dicts of the
    form ``{"type": ..., "payload": ...}``.
    """

    def __init__(self):
        # Set by run() once the websocket connection is established.
        self.ws = None
        # Queue for messages to be sent *to* Gemini
        self.out_queue = asyncio.Queue()
        # Queue for events received *from* Gemini: base64 audio chunks
        # ("audio") and tool calls ("function_call").
        self.audio_in_queue = asyncio.Queue()

    async def startup(self, tools=None):
        """Send the model setup message to Gemini and print its reply.

        Args:
            tools: Optional list of tools to enable for the model
        """
        setup_msg = {"setup": {"model": f"models/{model}"}}
        if tools:
            setup_msg["setup"]["tools"] = tools
        await self.ws.send(json.dumps(setup_msg))
        # The first message received after the setup request is treated as
        # the setup response.
        raw_response = await self.ws.recv()
        setup_response = json.loads(raw_response)
        print("[AudioLoop] Setup response from Gemini:", setup_response)

    async def send_realtime(self):
        """Read from out_queue and forward those messages to Gemini in real time.

        Loops forever; it ends only by cancellation or when sending raises.
        """
        while True:
            msg = await self.out_queue.get()
            await self.ws.send(json.dumps(msg))

    @staticmethod
    def _extract_audio_payloads(response):
        """Return the base64 payloads of every inline audio part in *response*.

        All parts of ``serverContent.modelTurn.parts`` are inspected, not just
        the first, so audio that follows a text or code part is not lost.
        Missing or oddly shaped fields yield an empty list instead of raising.
        """
        server_content = response.get("serverContent")
        if not isinstance(server_content, dict):
            return []
        model_turn = server_content.get("modelTurn")
        if not isinstance(model_turn, dict):
            return []

        payloads = []
        for part in model_turn.get("parts") or []:
            if not isinstance(part, dict):
                continue
            blob = part.get("inlineData")
            if not isinstance(blob, dict) or "data" not in blob:
                continue
            # Blobs without a MIME type are still treated as audio; blobs that
            # declare a non-audio type (e.g. an image) are skipped.
            mime_type = blob.get("mimeType")
            if mime_type and not str(mime_type).startswith("audio/"):
                continue
            try:
                pcm_data = base64.b64decode(blob["data"])
            except (ValueError, TypeError):
                # binascii.Error is a ValueError; one bad chunk must not end
                # the whole receive loop.
                print("[AudioLoop] Skipping undecodable inline audio chunk")
                continue
            # Re-encode so the client always receives canonical base64.
            payloads.append(base64.b64encode(pcm_data).decode())
        return payloads

    async def receive_audio(self):
        """Read from Gemini websocket and push events into audio_in_queue.

        Inline audio is queued as ``{"type": "audio", "payload": <base64>}``
        and tool calls as ``{"type": "function_call", "payload": <toolCall>}``.
        Returns when iteration over the websocket ends.
        """
        async for raw_response in self.ws:
            response = json.loads(raw_response)
            if not isinstance(response, dict):
                # Not a message object; nothing to extract.
                continue

            for payload in self._extract_audio_payloads(response):
                await self.audio_in_queue.put({
                    "type": "audio",
                    "payload": payload,
                })

            # Forward function calls to client
            tool_call = response.pop('toolCall', None)
            if tool_call is not None:
                await self.audio_in_queue.put({
                    "type": "function_call",
                    "payload": tool_call,
                })

            server_content = response.get("serverContent")
            if isinstance(server_content, dict) and server_content.get("turnComplete"):
                print("[AudioLoop] Gemini turn complete")

    async def handle_tool_call(self, tool_call_response):
        """Send the client's tool call result back to Gemini.

        Args:
            tool_call_response: Mapping with ``id``, ``name`` and ``response``
                keys; a missing key raises ``KeyError``.
        """
        msg = {
            'tool_response': {
                'function_responses': [{
                    'id': tool_call_response['id'],
                    'name': tool_call_response['name'],
                    'response': tool_call_response['response'],
                }]
            }
        }
        # Sent directly on the socket rather than through out_queue.
        await self.ws.send(json.dumps(msg))

    @staticmethod
    def _default_tools():
        """Build the tool list sent to Gemini in the setup message."""
        turn_on_the_lights_schema = {'name': 'turn_on_the_lights'}
        turn_off_the_lights_schema = {'name': 'turn_off_the_lights'}
        change_background_schema = {
            'name': 'change_background',
            'description': 'Changes the background color of the webpage',
            'parameters': {
                'type': 'object',
                'properties': {
                    'color': {
                        'type': 'string',
                        'description': 'Color to change the background to (e.g., red, blue, #FF0000)'
                    }
                }
            }
        }
        return [
            {'google_search': {}},
            {'function_declarations': [
                turn_on_the_lights_schema,
                turn_off_the_lights_schema,
                change_background_schema,
            ]},
            {'code_execution': {}},
        ]

    async def run(self):
        """Main entry point: connects to Gemini, starts send/receive tasks.

        Runs until cancelled, until a task fails, or until the receive loop
        ends because the websocket stopped yielding messages. Cancellation is
        logged and swallowed; any other exception is printed and re-raised.
        """
        try:
            tools = self._default_tools()
            async with connect(uri, additional_headers={"Content-Type": "application/json"}) as ws:
                self.ws = ws
                await self.startup(tools)
                try:
                    async with asyncio.TaskGroup() as tg:
                        send_task = tg.create_task(self.send_realtime())
                        receive_task = tg.create_task(self.receive_audio())
                        # receive_audio() returns once the socket stops
                        # yielding messages. Waiting on it, rather than on a
                        # never-resolved future, lets run() finish instead of
                        # hanging with the sender blocked on an idle queue.
                        await receive_task
                        send_task.cancel()
                    # The TaskGroup has cancelled and awaited both tasks by
                    # the time it exits, so no manual task cleanup is needed.
                finally:
                    try:
                        await self.ws.close()
                        print("[AudioLoop] Closed WebSocket connection")
                    except Exception as e:
                        print(f"[AudioLoop] Error closing Gemini connection: {e}")
                    print("[AudioLoop] Cleanup complete")
        except asyncio.CancelledError:
            print("[AudioLoop] Cancelled")
        except Exception:
            traceback.print_exc()
            raise