Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -19,22 +19,24 @@ opts = {
|
|
19 |
"8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
|
20 |
}
|
21 |
|
22 |
-
step_loaded =
|
23 |
-
|
24 |
-
|
25 |
-
pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=dtype, variant="fp16").to(device)
|
26 |
|
27 |
@spaces.GPU(enable_queue=True)
|
28 |
def generate_image(prompt, option):
|
29 |
global step_loaded
|
|
|
30 |
ckpt, step = opts[option]
|
31 |
if step != step_loaded:
|
|
|
32 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
|
33 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
|
34 |
step_loaded = step
|
35 |
return pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
|
36 |
|
37 |
-
with gr.Blocks() as demo:
|
38 |
gr.HTML(
|
39 |
"<h1><center>SDXL-Lightning</center></h1>" +
|
40 |
"<p><center>Lightning-fast text-to-image generation</center></p>" +
|
@@ -57,7 +59,7 @@ with gr.Blocks() as demo:
|
|
57 |
variant="primary"
|
58 |
)
|
59 |
|
60 |
-
img = gr.Image(label="SDXL-
|
61 |
|
62 |
prompt.submit(
|
63 |
fn=generate_image,
|
|
|
19 |
"8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
|
20 |
}
|
21 |
|
22 |
+
step_loaded = 4
|
23 |
+
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
|
24 |
+
unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]), device=device))
|
25 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
|
26 |
|
27 |
@spaces.GPU(enable_queue=True)
|
28 |
def generate_image(prompt, option):
|
29 |
global step_loaded
|
30 |
+
print(prompt, option)
|
31 |
ckpt, step = opts[option]
|
32 |
if step != step_loaded:
|
33 |
+
print(f"Switching checkpoint from {step_loaded} to {step}")
|
34 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
|
35 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
|
36 |
step_loaded = step
|
37 |
return pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
|
38 |
|
39 |
+
with gr.Blocks(css="style.css") as demo:
|
40 |
gr.HTML(
|
41 |
"<h1><center>SDXL-Lightning</center></h1>" +
|
42 |
"<p><center>Lightning-fast text-to-image generation</center></p>" +
|
|
|
59 |
variant="primary"
|
60 |
)
|
61 |
|
62 |
+
img = gr.Image(label="SDXL-Lighting Generated Image")
|
63 |
|
64 |
prompt.submit(
|
65 |
fn=generate_image,
|