File size: 6,190 Bytes
cb92d2b d6fedfa 46bd9ac cb92d2b 3207814 cb92d2b ff9325e 3207814 c6b09d3 cb92d2b 3207814 cb92d2b 3207814 cb92d2b 3207814 1d3190d 3207814 cb92d2b 3207814 cb92d2b 1d3190d 3207814 cb92d2b 3207814 cb92d2b 3207814 cb92d2b 1d3190d cb92d2b 3207814 cb92d2b d6fedfa cb92d2b 3207814 cb92d2b 1d3190d cb92d2b 3207814 1d3190d 3207814 1d3190d ff9325e cb92d2b 7d67dc6 cb92d2b d6fedfa 7d67dc6 cb92d2b d6fedfa cb92d2b 3207814 cb92d2b 46bd9ac 43148fd d6fedfa 46bd9ac d6fedfa 46bd9ac d6fedfa cb92d2b c6b09d3 cb92d2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi import Request
import markdown2
import logging
import traceback
from config import Args
from user_queue import UserData
import uuid
import time
from types import SimpleNamespace
from util import pil_to_frame, bytes_to_pil, is_firefox
import asyncio
import os
def init_app(app: FastAPI, user_data: UserData, args: Args, pipeline):
    """Register middleware, routes and static file serving on ``app``.

    Args:
        app: The FastAPI application to configure.
        user_data: Registry of connected users; used here to count, create,
            look up and delete users, and to store/fetch each user's latest
            input parameters and websocket.
        args: Server configuration. ``max_queue_size`` caps concurrent users
            and ``timeout`` caps session length in seconds; a value of 0 or
            less disables the respective limit.
        pipeline: Generation pipeline exposing ``Info``, ``InputParams`` and
            ``predict``.
    """
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """Accept a client, register it and pump its input until it leaves.

        The connection is rejected with an error message when
        ``args.max_queue_size`` users are already connected. The user is
        always removed from ``user_data`` when the session ends.
        """
        await websocket.accept()
        user_count = user_data.get_user_count()
        if args.max_queue_size > 0 and user_count >= args.max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            return

        # Bound before ``try`` so the ``except``/``finally`` clauses can
        # always refer to it.
        user_id = uuid.uuid4()
        try:
            print(f"New user connected: {user_id}")
            await user_data.create_user(user_id, websocket)
            await websocket.send_json(
                {"status": "connected", "message": "Connected", "userId": str(user_id)}
            )
            await websocket.send_json({"status": "send_frame"})
            await handle_websocket_data(user_id, websocket)
        except WebSocketDisconnect as e:
            # A client closing its socket is a normal end of session, not a
            # server fault, so no traceback is printed.
            logging.info(f"WebSocket disconnected: {e}, {user_id}")
        finally:
            print(f"User disconnected: {user_id}")
            user_data.delete_user(user_id)

    async def handle_websocket_data(user_id: uuid.UUID, websocket: WebSocket):
        """Receive input parameters from a client and store them for streaming.

        Per frame the client sends ``{"status": "next_frame"}``, then a JSON
        object of input parameters, then (only when the pipeline's
        ``input_mode`` is ``"image"``) the raw image bytes. Each accepted frame
        is acknowledged with ``{"status": "wait"}``.

        When ``args.timeout`` is positive, the elapsed time since this handler
        started is checked after each accepted frame; once it exceeds the
        timeout, a ``"timeout"`` message is sent and the socket is closed.

        ``WebSocketDisconnect`` propagates to the caller. Any other error is
        logged and ends the session.
        """
        if not user_data.check_user(user_id):
            # An HTTPException means nothing on a websocket; just stop.
            return
        # Pipeline metadata does not change per frame; build it once.
        info = pipeline.Info()
        session_start = time.time()
        try:
            while True:
                data = await websocket.receive_json()
                if data.get("status") != "next_frame":
                    # Must be awaited: a bare asyncio.sleep() call only creates
                    # a coroutine and never actually pauses.
                    await asyncio.sleep(1.0 / 24)
                    continue

                params = await websocket.receive_json()
                params = pipeline.InputParams(**params)
                params = SimpleNamespace(**params.dict())
                if info.input_mode == "image":
                    image_data = await websocket.receive_bytes()
                    params.image = bytes_to_pil(image_data)
                await user_data.update_data(user_id, params)
                await websocket.send_json({"status": "wait"})

                if args.timeout > 0 and time.time() - session_start > args.timeout:
                    await websocket.send_json(
                        {
                            "status": "timeout",
                            "message": "Your session has ended",
                            # uuid.UUID is not JSON-serializable.
                            "userId": str(user_id),
                        }
                    )
                    await websocket.close()
                    return
                await asyncio.sleep(1.0 / 24)
        except WebSocketDisconnect:
            # Let the endpoint handle disconnects as a normal session end.
            raise
        except Exception as e:
            logging.error(f"Error: {e}")
            traceback.print_exc()

    @app.get("/queue_size")
    async def get_queue_size():
        """Return the number of currently connected users."""
        queue_size = user_data.get_user_count()
        return JSONResponse({"queue_size": queue_size})

    @app.get("/stream/{user_id}")
    async def stream(user_id: uuid.UUID, request: Request):
        """Stream generated frames for ``user_id`` as multipart/x-mixed-replace.

        Frames are produced only when the user's latest parameters change.
        After each step the client is asked for a new frame over its websocket.
        The stream ends once the user is no longer registered.

        Raises:
            HTTPException: 404 if ``user_id`` is not a connected user.
        """
        if not user_data.check_user(user_id):
            raise HTTPException(status_code=404, detail="User not found")
        print(f"New stream request: {user_id}")
        # Evaluated once per request. A missing header must not raise
        # KeyError in the middle of the stream.
        duplicate_frames = not is_firefox(request.headers.get("user-agent", ""))

        async def generate():
            websocket = user_data.get_websocket(user_id)
            last_params = SimpleNamespace()
            # Stop once the websocket session has ended and the user was removed.
            while user_data.check_user(user_id):
                params = await user_data.get_latest_data(user_id)
                if not vars(params) or params.__dict__ == last_params.__dict__:
                    # Nothing new to render; ask the client for input and back off.
                    await websocket.send_json({"status": "send_frame"})
                    await asyncio.sleep(0.1)
                    continue
                last_params = params
                image = pipeline.predict(params)
                if image is None:
                    await websocket.send_json({"status": "send_frame"})
                    continue
                frame = pil_to_frame(image)
                yield frame
                # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
                if duplicate_frames:
                    yield frame
                await websocket.send_json({"status": "send_frame"})

        return StreamingResponse(
            generate(),
            media_type="multipart/x-mixed-replace;boundary=frame",
            headers={"Cache-Control": "no-cache"},
        )

    # route to setup frontend
    @app.get("/settings")
    async def settings():
        """Return pipeline/input schemas and page content for the frontend UI."""
        info_schema = pipeline.Info.schema()
        info = pipeline.Info()
        page_content = markdown2.markdown(info.page_content) if info.page_content else ""
        input_params = pipeline.InputParams.schema()
        return JSONResponse(
            {
                "info": info_schema,
                "input_params": input_params,
                "max_queue_size": args.max_queue_size,
                "page_content": page_content,
            }
        )

    # StaticFiles requires its directory to exist at mount time.
    os.makedirs("public", exist_ok=True)
    # Mounted last: a mount at "/" matches every path, so any route added
    # after it would be shadowed.
    app.mount("/", StaticFiles(directory="public", html=True), name="public")
|