# Live2Diff / demo / connection_manager.py
# Last commit by: leoxing1996
# Commit message: add demo
# Commit: d16b52d
import asyncio
import logging
from types import SimpleNamespace
from typing import Dict, Optional, Union
from uuid import UUID

from fastapi import WebSocket
from starlette.websockets import WebSocketState
# Maps each connected user's UUID to that user's session record: a dict whose
# "websocket" entry is the live WebSocket and whose "queue" entry is the
# asyncio.Queue buffering data for that user.
Connections = Dict[
    UUID,
    Dict[str, Union[WebSocket, asyncio.Queue]],
]
class ServerFullException(Exception):
    """Signals that no further users can be admitted because the server is at capacity."""
class ConnectionManager:
    """Track per-user WebSocket sessions and the data queued for each user.

    ``active_connections`` maps a user's UUID to a session dict holding the
    user's ``"websocket"`` and an ``asyncio.Queue`` under ``"queue"``.
    """

    def __init__(self):
        self.active_connections: Connections = {}

    async def connect(
        self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
    ):
        """Accept ``websocket`` and register it as ``user_id``'s session.

        Args:
            user_id: Key under which the session is stored.
            websocket: Incoming WebSocket. It is accepted before the capacity
                check so a refused client can be told why.
            max_queue_size: Maximum number of concurrent sessions; a value
                of ``0`` (or any non-positive value) disables the limit.

        Raises:
            ServerFullException: If the session limit has been reached. The
                socket has already been sent an error message and closed.
        """
        await websocket.accept()
        user_count = self.get_user_count()
        print(f"User count: {user_count}")
        if max_queue_size > 0 and user_count >= max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")

        print(f"New user connected: {user_id}")
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }
        await websocket.send_json(
            {"status": "connected", "message": "Connected"},
        )
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return True if ``user_id`` currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: SimpleNamespace):
        """Enqueue ``new_data`` for ``user_id``; silently ignored for unknown users."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            await user_session["queue"].put(new_data)

    async def get_latest_data(self, user_id: UUID) -> Optional[SimpleNamespace]:
        """Wait for and return the next item queued for ``user_id``.

        Items are returned in FIFO order (oldest first). The call suspends
        until an item is available; it does not return early on an empty
        queue.

        Returns:
            The dequeued item, or ``None`` if ``user_id`` has no session.
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return None
        # asyncio.Queue.get() suspends until an item arrives and never raises
        # QueueEmpty (only get_nowait() does), so no empty-queue handler is needed.
        return await user_session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove ``user_id``'s session and discard any data still queued for it.

        NOTE(review): coroutines already suspended in ``queue.get()`` on this
        queue are not woken by this; they keep waiting on the orphaned queue.
        """
        user_session = self.active_connections.pop(user_id, None)
        if user_session:
            queue = user_session["queue"]
            while not queue.empty():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    continue

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> Optional[WebSocket]:
        """Return ``user_id``'s WebSocket if it is still in the CONNECTED state, else None."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            websocket = user_session["websocket"]
            if websocket.client_state == WebSocketState.CONNECTED:
                return websocket
        return None

    async def disconnect(self, user_id: UUID):
        """Close ``user_id``'s socket (if still connected) and drop the session.

        The session is removed even when closing the socket raises. Otherwise
        a failed close would leave a stale entry that keeps counting toward
        the session limit enforced in ``connect``.
        """
        websocket = self.get_websocket(user_id)
        try:
            if websocket:
                await websocket.close()
        finally:
            self.delete_user(user_id)

    async def send_json(self, user_id: UUID, data: Dict):
        """Send ``data`` as JSON to ``user_id`` if their socket is connected.

        Best-effort: send errors are logged and swallowed.
        """
        websocket = self.get_websocket(user_id)
        if websocket is None:
            return
        try:
            await websocket.send_json(data)
        except Exception as e:
            logging.error(f"Error: Send json: {e}")

    async def receive_json(self, user_id: UUID) -> Optional[Dict]:
        """Receive one JSON message from ``user_id``.

        Returns:
            The decoded message. Returns ``None`` if the user's socket is not
            connected or if receiving fails (the error is logged).
        """
        websocket = self.get_websocket(user_id)
        if websocket is None:
            return None
        try:
            return await websocket.receive_json()
        except Exception as e:
            logging.error(f"Error: Receive json: {e}")
            return None

    async def receive_bytes(self, user_id: UUID) -> Optional[bytes]:
        """Receive one binary message from ``user_id``.

        Returns:
            The raw bytes. Returns ``None`` if the user's socket is not
            connected or if receiving fails (the error is logged).
        """
        websocket = self.get_websocket(user_id)
        if websocket is None:
            return None
        try:
            return await websocket.receive_bytes()
        except Exception as e:
            logging.error(f"Error: Receive bytes: {e}")
            return None