Spaces:
Sleeping
Sleeping
File size: 4,516 Bytes
b0f3145 8425644 b0f3145 0b120d5 b0f3145 84c6537 88dc089 b0f3145 8425644 b0f3145 23f3ac6 7c672bb 23f3ac6 7c672bb 23f3ac6 b0f3145 edf024c 23f3ac6 b0f3145 75859e2 c811b57 7c672bb b0f3145 edf024c 0b120d5 1db955a c811b57 7c672bb f1e3c7d 8425644 4f9929e 1db955a 5db2f57 0b120d5 b0f3145 7c672bb 7e06c4d 88dc089 7e06c4d b0f3145 0ce0e61 b0f3145 a364bc6 b0f3145 75859e2 b0f3145 75859e2 b0f3145 75859e2 bef93f3 75859e2 bef93f3 75859e2 bef93f3 bbd2321 bef93f3 b0f3145 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
# Fail fast with a clear error; unlike `assert`, this check survives `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required to run this demo.")
device = "cuda"
dtype = torch.float16
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
# Dropdown label -> (UNet checkpoint filename inside `repo`, number of inference steps).
opts = {
    "1 Step": ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps": ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps": ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps": ("sdxl_lightning_8step_unet.safetensors", 8),
}
# Default to load the 4-step model. Checkpoint and step count both come from
# `opts`, so they cannot drift apart.
DEFAULT_OPTION = "4 Steps"
_default_ckpt, step_loaded = opts[DEFAULT_OPTION]
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, _default_ckpt)))
pipe = StableDiffusionXLPipeline.from_pretrained(
    base, unet=unet, torch_dtype=dtype, variant="fp16"
).to(device, dtype)
# Configure the scheduler for whichever checkpoint was loaded. The 1-step
# checkpoint (`..._x0`) uses prediction_type="sample"; the others use "epsilon".
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    prediction_type="sample" if step_loaded == 1 else "epsilon",
)
# Blocked prompt terms, one per line. splitlines() also handles CRLF files
# (split("\n") would leave "\r" on each term so it never matched), and
# whitespace-only lines are skipped because they would match almost any prompt.
# Non-blank terms are kept verbatim (including any deliberate padding spaces).
with open("filter.txt", encoding="utf-8") as f:
    filter_words = {word for word in f.read().splitlines() if word.strip()}
def _blocked_image(prompt):
    """Warn the user, log the refused `prompt`, and return a black 512x512 placeholder."""
    gr.Warning("Safety checker triggered.")
    print(f"Safety checker triggered on prompt: {prompt}")
    return Image.new("RGB", (512, 512))


# Inference function.
@spaces.GPU(enable_queue=True)
def generate(prompt, option, progress=gr.Progress()):
    """Generate one image for `prompt` using the checkpoint selected by `option`.

    Args:
        prompt: Text prompt from the UI.
        option: A key of `opts` (e.g. "4 Steps"); selects the UNet checkpoint
            and the number of inference steps.
        progress: Gradio progress tracker; updated with (done, total) steps.

    Returns:
        The first generated PIL image. If the prompt contains a term from
        `filter_words` (case-insensitive substring match), or the pipeline
        output reports NSFW content, returns a black 512x512 placeholder instead.

    Raises:
        KeyError: if `option` is not a key of `opts`.
    """
    global step_loaded
    print(prompt, option)
    ckpt, step = opts[option]

    # Compare case-insensitively so capitalisation cannot bypass the filter.
    lowered_prompt = prompt.lower()
    if any(word.lower() in lowered_prompt for word in filter_words):
        return _blocked_image(prompt)

    progress((0, step))
    if step != step_loaded:
        print(f"Switching checkpoint from {step_loaded} to {step}")
        # Load the new UNet weights BEFORE touching the scheduler. If the
        # download or load raises, the scheduler and `step_loaded` then still
        # describe the weights that are actually loaded.
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
        # The 1-step checkpoint uses prediction_type="sample"; the others "epsilon".
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if step == 1 else "epsilon",
        )
        step_loaded = step

    def inference_callback(p, i, t, kwargs):
        # Called by the pipeline after each denoising step; must return the kwargs dict.
        progress((i + 1, step))
        return kwargs

    # guidance_scale=0 disables classifier-free guidance.
    results = pipe(
        prompt,
        num_inference_steps=step,
        guidance_scale=0,
        callback_on_step_end=inference_callback,
    )
    # Only some pipeline outputs carry NSFW flags; treat missing/None/empty as "not detected".
    nsfw_flags = getattr(results, "nsfw_content_detected", None)
    if nsfw_flags and nsfw_flags[0]:
        return _blocked_image(prompt)
    return results.images[0]
# Static HTML shown above and below the app (adjacent literals concatenate).
_HEADER_HTML = (
    "<h1><center>SDXL-Lightning</center></h1>"
    "<p><center>Lightning-fast text-to-image generation</center></p>"
    "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
)
_FOOTER_HTML = "<p><small><center>This demo is built together by the community</center></small></p>"

# Example rows: [prompt, dropdown option]. With cache_examples=True Gradio
# precomputes their outputs by calling `generate`.
_EXAMPLES = [
    ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
    ["A lion in the galaxy, octane render", "2 Steps"],
    ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
    ["A girl smiling", "4 Steps"],
    ["An astronaut riding a horse", "4 Steps"],
    ["A fish on a bicycle, colorful art", "4 Steps"],
    ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
    ["Man portrait, ethereal", "8 Steps"],
    ["Rabbit portrait in a forest, fantasy", "8 Steps"],
    ["A panda swimming", "8 Steps"],
]

# Layout: header, one row of [prompt | steps dropdown | run button], output
# image, examples, footer.
with gr.Blocks(css="style.css") as demo:
    gr.HTML(_HEADER_HTML)
    with gr.Row():
        prompt = gr.Textbox(label="Text prompt", scale=8)
        option = gr.Dropdown(
            label="Inference steps",
            choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
            value="4 Steps",
            interactive=True,
        )
        submit = gr.Button(scale=1, variant="primary")
    img = gr.Image(label="SDXL-Lightning Generated Image")

    # Pressing Enter in the textbox and clicking the button run the same handler.
    for _event in (prompt.submit, submit.click):
        _event(fn=generate, inputs=[prompt, option], outputs=img)

    gr.Examples(
        fn=generate,
        examples=_EXAMPLES,
        inputs=[prompt, option],
        outputs=img,
        cache_examples=True,
    )
    gr.HTML(_FOOTER_HTML)

demo.queue().launch()