Spaces:
Sleeping
Sleeping
File size: 3,145 Bytes
b0f3145 8425644 b0f3145 84c6537 88dc089 b0f3145 8425644 b0f3145 23f3ac6 7c672bb 23f3ac6 7c672bb 23f3ac6 b0f3145 23f3ac6 b0f3145 75859e2 c811b57 7c672bb b0f3145 1db955a c811b57 7c672bb f1e3c7d 8425644 4f9929e 1db955a 5db2f57 4f9929e b0f3145 7c672bb 7e06c4d 88dc089 7e06c4d b0f3145 0ce0e61 b0f3145 7c672bb b0f3145 75859e2 b0f3145 75859e2 b0f3145 75859e2 b0f3145 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# Fail fast with a clear message: the pipeline below is moved to CUDA
# unconditionally. An explicit raise (unlike `assert`) survives `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("SDXL-Lightning demo requires a CUDA-capable GPU.")

device = "cuda"
dtype = torch.float16

# Base SDXL model (pipeline components and UNet architecture) and the repo
# holding the SDXL-Lightning UNet checkpoints.
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

# UI label -> (UNet checkpoint filename inside `repo`, number of inference steps).
# The 1-step file is the "x0" variant, which `generate` pairs with a
# prediction_type="sample" scheduler; the others use "epsilon".
opts = {
    "1 Step": ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps": ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps": ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps": ("sdxl_lightning_8step_unet.safetensors", 8),
}

# Default to load the 4-step model. `step_loaded` records which checkpoint is
# currently in the UNet so `generate` only reloads weights when it changes.
default_option = "4 Steps"
default_ckpt, step_loaded = opts[default_option]

# Build the UNet from the base model's config only (load_config + from_config;
# passing a repo id straight to from_config is deprecated in diffusers), then
# fill it with the Lightning checkpoint weights.
unet = UNet2DConditionModel.from_config(UNet2DConditionModel.load_config(base, subfolder="unet"))
unet.load_state_dict(load_file(hf_hub_download(repo, default_ckpt)))

# The remaining components come from the base model; the whole pipeline is
# moved to the GPU in half precision.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, unet=unet, torch_dtype=dtype, variant="fp16"
).to(device, dtype)

# Euler scheduler with "trailing" timestep spacing. The prediction type follows
# the same rule `generate` uses when it switches checkpoints.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample" if step_loaded == 1 else "epsilon",
)
# Inference function.
@spaces.GPU(enable_queue=True)
def generate(prompt, option, progress=gr.Progress()):
    """Generate one image for `prompt` with the SDXL-Lightning checkpoint selected by `option`.

    If `option` names a different checkpoint than the one currently loaded,
    the UNet weights and the scheduler are swapped first. That state is kept
    in the module-level `pipe` and `step_loaded`.

    Args:
        prompt: Text prompt passed to the pipeline.
        option: A key of `opts` ("1 Step", "2 Steps", "4 Steps" or "8 Steps")
            choosing the UNet checkpoint and the number of inference steps.
        progress: Gradio progress tracker, advanced once per denoising step.

    Returns:
        The first image in the pipeline output's `.images`.

    Raises:
        gr.Error: if `option` is not a key of `opts`.
    """
    global step_loaded
    print(prompt, option)
    if option not in opts:
        # Surface a readable message in the UI instead of a bare KeyError.
        raise gr.Error(f"Unknown inference option: {option!r}")
    ckpt, step = opts[option]
    progress((0, step))

    if step != step_loaded:
        print(f"Switching checkpoint from {step_loaded} to {step}")
        # Download and load the weights BEFORE touching the scheduler. If the
        # download or parse raises, the scheduler and `step_loaded` still
        # describe the old checkpoint, so a later request cannot run old
        # weights with the new prediction type.
        state_dict = load_file(hf_hub_download(repo, ckpt), device=device)
        pipe.unet.load_state_dict(state_dict)
        # The 1-step "x0" checkpoint needs prediction_type="sample"; the others use "epsilon".
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if step == 1 else "epsilon",
        )
        step_loaded = step

    def inference_callback(p, i, t, kwargs):
        # Invoked by the pipeline at the end of each step; report progress and
        # hand the callback kwargs back unchanged.
        progress((i + 1, step))
        return kwargs

    return pipe(
        prompt,
        num_inference_steps=step,
        guidance_scale=0,
        callback_on_step_end=inference_callback,
    ).images[0]
with gr.Blocks(css="style.css") as demo:
    # Page header: title, tagline and a link to the model repository.
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>"
        "<p><center>Lightning-fast text-to-image generation</center></p>"
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )

    # Input row: prompt box, step-count selector and run button.
    with gr.Row():
        prompt = gr.Textbox(
            label="Text prompt",
            scale=8,
        )
        # Choices are taken from `opts`, so the dropdown can only offer
        # options that `generate` knows how to load.
        option = gr.Dropdown(
            label="Inference steps",
            choices=list(opts),
            value="4 Steps",
            interactive=True,
        )
        submit = gr.Button(
            scale=1,
            variant="primary",
        )

    img = gr.Image(label="SDXL-Lightning Generated Image")

    # Pressing Enter in the prompt box and clicking the button both run `generate`.
    prompt.submit(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
    submit.click(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )

    gr.Examples(
        fn=generate,
        examples=[
            ["A girl smiling", "4 Steps"],
            ["An astronaut riding a horse", "4 Steps"],
        ],
        inputs=[prompt, option],
        outputs=img,
        cache_examples=True,
    )
# Turn on Gradio's request queue for the app, then start serving it.
demo.queue()
demo.launch()