import json
import threading
from collections import deque
from dataclasses import dataclass
from typing import Optional

import gradio as gr
import websockets
from gradio.processing_utils import decode_base64_to_image, encode_pil_to_base64
from PIL import Image
from websockets.sync.client import connect

from constants import DESCRIPTION, LOGO, WS_ADDRESS
from gradio_examples import EXAMPLES
from utils import replace_background
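
# Bounded queues: deque(maxlen=MAX_QUEUE_SIZE) drops the oldest entry when a
# new one is appended to a full queue, so stale requests are shed instead of
# piling up behind slow inference.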
MAX_QUEUE_SIZE = 4

@dataclass
class GenerationState:
    prompts: deque
    responses: deque


def get_initial_state() -> GenerationState:
    return GenerationState(
        prompts=deque(maxlen=MAX_QUEUE_SIZE),
        responses=deque(maxlen=MAX_QUEUE_SIZE),
    )

def load_initial_state(request: gr.Request) -> GenerationState:
    print("Loading initial state for", request.client.host)
    print("Total number of active threads", threading.active_count())
    return get_initial_state()
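
# Event handler: enqueue the latest sketch + prompt. With the queue bounded,
# only the most recent MAX_QUEUE_SIZE pending requests are kept.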
async def put_to_queue(
    image: Optional[Image.Image],
    prompt: str,
    seed: int,
    strength: float,
    state: GenerationState,
):
    prompts_queue = state.prompts

    if prompt and image is not None:
        prompts_queue.append((image, prompt, seed, strength))

    return state
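
# Polled every 0.1 s (see demo.load below): pops one pending request, performs
# a websocket round-trip to the inference server, and queues the response.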
def send_inference_request(state: GenerationState) -> GenerationState:
    prompts_queue = state.prompts
    response_queue = state.responses

    if len(prompts_queue) == 0:
        return state

    image, prompt, seed, strength = prompts_queue.popleft()

    original_image_size = image.size
    image = replace_background(image.resize((512, 512)))

    arguments = {
        "prompt": prompt,
        "image_url": encode_pil_to_base64(image),
        "strength": strength,
        "negative_prompt": "cartoon, illustration, animation. face. male, female",
        "seed": seed,
        "guidance_scale": 1,
        "num_inference_steps": 4,
        "sync_mode": 1,
        "num_images": 1,
    }

    connection = connect(WS_ADDRESS)
    connection.send(json.dumps(arguments))

    try:
        response = json.loads(connection.recv())
    except websockets.exceptions.ConnectionClosedOK:
        print("Connection closed, reconnecting...")
        # TODO: This is a hacky way to reconnect, but it works for now.
        # Ideally, we should be able to reuse the same connection instead of
        # creating a new one.
        connection = connect(WS_ADDRESS)
        # Re-send the request on the fresh connection before waiting for a reply.
        connection.send(json.dumps(arguments))
        try:
            response = json.loads(connection.recv())
        except websockets.exceptions.ConnectionClosedOK:
            print("Connection closed again, aborting...")
            return state

    # TODO: If a new connection had to be created, the response may not
    # contain the images.
    if "images" in response:
        response_queue.append((response, original_image_size))

    return state
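
# Also polled every 0.1 s: drains one finished response, decodes the base64
# image, and resizes it back to the original canvas size for display.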
def update_output_image(state: GenerationState):
    image_update = gr.update()
    inference_time_update = gr.update()

    response_queue = state.responses

    if len(response_queue) > 0:
        response, original_image_size = response_queue.popleft()
        generated_image = decode_base64_to_image(response["images"][0]["url"])
        inference_time = response["timings"]["inference"]
        image_update = gr.update(value=generated_image.resize(original_image_size))
        inference_time_update = gr.update(value=round(inference_time, 4))

    return image_update, inference_time_update, state
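
# UI: a sketch canvas next to the live output image, a prompt box with
# inference timing, and advanced sliders for strength and seed.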
with gr.Blocks(css="style.css", title="Realtime Latent Consistency Model") as demo:
    generation_state = gr.State(get_initial_state())

    gr.HTML(f'<div style="width: 70px;">{LOGO}</div>')
    gr.Markdown(DESCRIPTION)

    with gr.Row(variant="default"):
        input_image = gr.Image(
            tool="color-sketch",
            source="canvas",
            label="Initial Image",
            type="pil",
            height=512,
            width=512,
            brush_radius=40.0,
        )
        output_image = gr.Image(
            label="Generated Image",
            type="pil",
            interactive=False,
            elem_id="output_image",
        )

    with gr.Row():
        with gr.Column(scale=23):
            prompt_box = gr.Textbox(label="Prompt", value=EXAMPLES[0])
        with gr.Column(scale=1):
            inference_time_box = gr.Number(
                label="Inference Time (s)", interactive=False
            )

    with gr.Accordion(label="Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                strength = gr.Slider(
                    label="Strength",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.8,
                    info="""
                    Strength of the initial image that will be applied during inference.
                    """,
                )
            with gr.Column():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2**31 - 1,
                    step=1,
                    randomize=True,
                    info="""
                    Seed for the random number generator.
                    """,
                )
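
    # One-time per-session state init, plus two pollers (every=0.1) that push
    # queued requests to the server and pull finished images into the UI.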
    demo.load(
        load_initial_state,
        outputs=[generation_state],
    )
    demo.load(
        send_inference_request,
        inputs=[generation_state],
        outputs=[generation_state],
        every=0.1,
    )
    demo.load(
        update_output_image,
        inputs=[generation_state],
        outputs=[output_image, inference_time_box, generation_state],
        every=0.1,
    )
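
    # Any input change (sketch, prompt, seed, strength) enqueues a new request.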
    for event in [input_image.change, prompt_box.change, strength.change, seed.change]:
        event(
            put_to_queue,
            [input_image, prompt_box, seed, strength, generation_state],
            [generation_state],
            show_progress=False,
            queue=True,
        )

    gr.Markdown("## Example Prompts")
    gr.Examples(examples=EXAMPLES, inputs=[prompt_box], label="Examples")

if __name__ == "__main__":
    demo.queue(concurrency_count=20, api_open=False).launch(max_threads=8192)