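# Gradio demo for ByteDance/SDXL-Lightning: fast text-to-image generation
# using distilled 1/2/4/8-step UNet checkpoints on top of the SDXL base model.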
import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
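
# Fail fast if no CUDA device is visible; everything below runs in fp16 on GPU.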
assert torch.cuda.is_available()
device = "cuda"
dtype = torch.float16
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
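# Map each UI choice to (checkpoint filename, number of inference steps).
# The 1-step checkpoint predicts x0 ("sample" prediction), so generate()
# switches the scheduler's prediction_type whenever it is selected.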
opts = {
    "1 Step": ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps": ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps": ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps": ("sdxl_lightning_8step_unet.safetensors", 8),
}
# Load the 4-step checkpoint by default.
step_loaded = 4
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0])))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
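# SDXL-Lightning checkpoints expect "trailing" timestep spacing.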
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# Inference function.
@spaces.GPU(enable_queue=True)
def generate(prompt, option, progress=gr.Progress()):
    global step_loaded
    print(prompt, option)
    ckpt, step = opts[option]
    progress((0, step))
    if step != step_loaded:
        print(f"Switching checkpoint from {step_loaded} to {step}")
        pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
        step_loaded = step
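    # Per-step callback: diffusers invokes this after each denoising step and
    # expects the (possibly modified) callback kwargs to be returned.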
    def inference_callback(p, i, t, kwargs):
        progress((i+1, step))
        return kwargs
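    # guidance_scale=0 disables classifier-free guidance; the Lightning
    # checkpoints are distilled to run without it.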
    return pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback).images[0]
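
# Build the Gradio UI.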
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>" +
        "<p><center>Lightning-fast text-to-image generation</center></p>" +
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )
    with gr.Row():
        prompt = gr.Textbox(
            label="Text prompt",
            scale=8
        )
        option = gr.Dropdown(
            label="Inference steps",
            choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
            value="4 Steps",
            interactive=True
        )
        submit = gr.Button(
            scale=1,
            variant="primary"
        )
    img = gr.Image(label="SDXL-Lightning Generated Image")
    prompt.submit(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
    submit.click(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
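    # With cache_examples=True, example outputs are generated once at startup
    # and then served instantly when an example is clicked.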
    gr.Examples(
        fn=generate,
        examples=[
            ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
            ["A lion in the galaxy, octane render", "2 Steps"],
            ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
            ["A girl smiling", "4 Steps"],
            ["An astronaut riding a horse", "4 Steps"],
            ["A fish on a bicycle, colorful art", "4 Steps"],
            ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
            ["Man portrait, ethereal", "8 Steps"],
            ["Rabbit portrait in a forest, fantasy", "8 Steps"],
            ["A panda swimming", "8 Steps"],
        ],
        inputs=[prompt, option],
        outputs=img,
        cache_examples=True,
    )
    gr.HTML(
        "<p><small><center>This demo is built by the community.</center></small></p>"
    )
demo.queue().launch()