# Live2Diff / demo / app.py
# leoxing1996
# add demo
# d16b52d
import logging
import mimetypes
import os
import time
import uuid
from types import SimpleNamespace
import markdown2
import torch
from config import Args, config
from connection_manager import ConnectionManager, ServerFullException
from fastapi import FastAPI, HTTPException, Request, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from util import bytes_to_pil, pil_to_frame
from vid2vid import Pipeline
# Work around the MIME error seen on Windows by registering ".js" explicitly.
mimetypes.add_type("application/javascript", ".js", strict=True)

# NOTE(review): presumably a per-frame throttle interval in seconds (1/120);
# it is not referenced in the visible code — confirm before relying on it.
THROTTLE = 1 / 120

# Uncomment to get verbose logging while debugging:
# logging.basicConfig(level=logging.DEBUG)
class App:
    """Serve the vid2vid pipeline through a FastAPI application.

    The instance owns the FastAPI object (``self.app``), the ``Pipeline``
    that turns input parameters into images, and the ``ConnectionManager``
    used to talk to each user's websocket. All routes are registered by
    :meth:`init_app`, which ``__init__`` calls.
    """

    def __init__(self, config: Args):
        """Create the pipeline, the FastAPI app and register every route.

        Args:
            config: Server arguments. The routes read ``max_queue_size``,
                ``timeout`` and ``debug`` from it; the whole object is also
                handed to ``Pipeline``.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch_dtype = torch.float16
        pipeline = Pipeline(config, device, torch_dtype)
        self.args = config
        self.pipeline = pipeline
        self.app = FastAPI()
        self.conn_manager = ConnectionManager()
        self.init_app()

    def init_app(self):
        """Register CORS middleware, the API routes and the static frontend."""
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        @self.app.websocket("/api/ws/{user_id}")
        async def websocket_endpoint(user_id: uuid.UUID, websocket: WebSocket):
            """Register the user's websocket and process its messages.

            A ``ServerFullException`` from ``connect`` is logged rather than
            propagated. The user is always disconnected on the way out.
            """
            try:
                await self.conn_manager.connect(user_id, websocket, self.args.max_queue_size)
                await handle_websocket_data(user_id)
            except ServerFullException as e:
                logging.error("Server Full: %s", e)
            finally:
                await self.conn_manager.disconnect(user_id)
                logging.info("User disconnected: %s", user_id)

        async def handle_websocket_data(user_id: uuid.UUID):
            """Receive frame requests from one user until timeout or error.

            Per iteration: wait for a JSON message; when its ``status`` is
            ``"next_frame"``, read a second JSON message holding the input
            parameters, validate it through ``pipeline.InputParams`` and, in
            ``"image"`` input mode, read the image bytes that follow. The
            resulting parameters are stored with ``update_data``.
            """
            if not self.conn_manager.check_user(user_id):
                # An HTTPException has no meaning on an open websocket and the
                # caller ignores the return value, so log and stop instead.
                logging.warning("Websocket Error: user not found, %s", user_id)
                return

            # Never refreshed below, so the timeout caps the total session
            # length rather than the idle time between messages.
            last_time = time.time()
            try:
                while True:
                    if self.args.timeout > 0 and time.time() - last_time > self.args.timeout:
                        await self.conn_manager.send_json(
                            user_id,
                            {
                                "status": "timeout",
                                "message": "Your session has ended",
                            },
                        )
                        await self.conn_manager.disconnect(user_id)
                        return

                    data = await self.conn_manager.receive_json(user_id)
                    if data["status"] == "next_frame":
                        info = self.pipeline.Info()
                        params = await self.conn_manager.receive_json(user_id)
                        params = self.pipeline.InputParams(**params)
                        # Downstream code gets plain attribute access and can
                        # have extra fields (``image``) attached.
                        params = SimpleNamespace(**params.model_dump())
                        if info.input_mode == "image":
                            image_data = await self.conn_manager.receive_bytes(user_id)
                            if len(image_data) == 0:
                                # Empty payload: ask the client for a frame again.
                                await self.conn_manager.send_json(user_id, {"status": "send_frame"})
                                continue
                            params.image = bytes_to_pil(image_data)
                        await self.conn_manager.update_data(user_id, params)
            except Exception as e:
                logging.error("Websocket Error: %s, %s ", e, user_id)
                await self.conn_manager.disconnect(user_id)

        @self.app.get("/api/queue")
        async def get_queue_size():
            """Return the current number of connected users as ``queue_size``."""
            queue_size = self.conn_manager.get_user_count()
            return JSONResponse({"queue_size": queue_size})

        @self.app.get("/api/stream/{user_id}")
        async def stream(user_id: uuid.UUID, request: Request):
            """Stream generated frames to the user as a multipart response.

            Raises:
                HTTPException: 404 when the streaming response cannot be
                    created.
            """

            async def generate():
                """Yield one encoded frame per successful prediction."""
                while True:
                    last_time = time.time()
                    # Ask the client for its next frame, then use whatever
                    # parameters it has most recently delivered.
                    await self.conn_manager.send_json(user_id, {"status": "send_frame"})
                    params = await self.conn_manager.get_latest_data(user_id)
                    if params is None:
                        continue
                    image = self.pipeline.predict(params)
                    if image is None:
                        continue
                    frame = pil_to_frame(image)
                    yield frame
                    if self.args.debug:
                        print(f"Time taken: {time.time() - last_time}")

            # NOTE: the generator body runs only after this handler has
            # returned, so errors raised while streaming are not caught here.
            try:
                return StreamingResponse(
                    generate(),
                    media_type="multipart/x-mixed-replace;boundary=frame",
                    headers={"Cache-Control": "no-cache"},
                )
            except Exception as e:
                logging.error("Streaming Error: %s, %s ", e, user_id)
                # Must be raised (not returned) for FastAPI to send the 404.
                raise HTTPException(status_code=404, detail="User not found") from e

        # route to setup frontend
        @self.app.get("/api/settings")
        async def settings():
            """Return the JSON schemas and page content the frontend needs."""
            info_schema = self.pipeline.Info.model_json_schema()
            info = self.pipeline.Info()
            # Render the optional markdown description; empty string if unset.
            page_content = markdown2.markdown(info.page_content) if info.page_content else ""
            input_params = self.pipeline.InputParams.model_json_schema()
            return JSONResponse(
                {
                    "info": info_schema,
                    "input_params": input_params,
                    "max_queue_size": self.args.max_queue_size,
                    "page_content": page_content,
                }
            )

        # StaticFiles refuses to start when its directory is missing, so make
        # sure the directory that is actually mounted exists.
        static_dir = "./frontend/public"
        os.makedirs(static_dir, exist_ok=True)
        self.app.mount("/", StaticFiles(directory=static_dir, html=True), name="public")
# Module-level FastAPI instance. Building ``App`` here means the pipeline is
# constructed and all routes are registered when this module is imported.
# NOTE(review): presumably this is the object an ASGI server is pointed at — confirm.
app = App(config).app