import threading from collections import deque from dataclasses import dataclass from typing import Optional import gradio as gr from PIL import Image from constants import DESCRIPTION, LOGO from gradio_examples import EXAMPLES from model import get_pipeline from utils import replace_background MAX_QUEUE_SIZE = 4 pipeline = get_pipeline() @dataclass class GenerationState: prompts: deque generations: deque def get_initial_state() -> GenerationState: return GenerationState( prompts=deque(maxlen=MAX_QUEUE_SIZE), generations=deque(maxlen=MAX_QUEUE_SIZE), ) def load_initial_state(request: gr.Request) -> GenerationState: print("Loading initial state for", request.client.host) print("Total number of active threads", threading.active_count()) return get_initial_state() async def put_to_queue( image: Optional[Image.Image], prompt: str, seed: int, strength: float, state: GenerationState, ): prompts_queue = state.prompts if prompt and image is not None: prompts_queue.append((image, prompt, seed, strength)) return state def inference(state: GenerationState) -> Image.Image: prompts_queue = state.prompts generations_queue = state.generations if len(prompts_queue) == 0: return state image, prompt, seed, strength = prompts_queue.popleft() original_image_size = image.size image = replace_background(image.resize((512, 512))) result = pipeline( prompt=prompt, image=image, strength=strength, seed=seed, guidance_scale=1, num_inference_steps=4, ) output_image = result.images[0].resize(original_image_size) generations_queue.append(output_image) return state def update_output_image(state: GenerationState): image_update = gr.update() generations_queue = state.generations if len(generations_queue) > 0: generated_image = generations_queue.popleft() image_update = gr.update(value=generated_image) return image_update, state with gr.Blocks(css="style.css", title=f"Realtime Latent Consistency Model") as demo: generation_state = gr.State(get_initial_state()) gr.HTML(f'