import gradio as gr
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.image_processor import VaeImageProcessor
from transformers import CLIPImageProcessor
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

device = "cuda"
dtype = torch.float16
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
opts = {
    "1 Step": ("sdxl_lightning_1step_unet_x0.safetensors", 1),
    "2 Steps": ("sdxl_lightning_2step_unet.safetensors", 2),
    "4 Steps": ("sdxl_lightning_4step_unet.safetensors", 4),
    "8 Steps": ("sdxl_lightning_8step_unet.safetensors", 8),
}
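
# Per the SDXL-Lightning model card, the 1-step checkpoint (note the "_x0"
# suffix) is experimental and predicts x0 ("sample") instead of noise
# ("epsilon"); generate() switches the scheduler accordingly in the code below.
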
# Default to load 4-step model.
step_loaded = 4
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0])))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
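
# SDXL-Lightning's model card calls for "trailing" timestep spacing so that
# few-step sampling starts from the final (highest-noise) trained timestep.
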
# Safety checker.
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
image_processor = VaeImageProcessor(vae_scale_factor=8)
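
# StableDiffusionSafetyChecker scores CLIP embeddings of each generated image
# against fixed NSFW concept embeddings and blanks flagged images in place;
# VaeImageProcessor (scale factor 8, matching the SDXL VAE) handles the
# tensor-to-PIL conversion.
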
# Inference function.
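# @spaces.GPU requests an on-demand ZeroGPU device for each call, which is why
# the models above are built on CPU and moved to CUDA lazily inside generate().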
@spaces.GPU(enable_queue=True)
def generate(prompt, option, progress=gr.Progress()):
    global step_loaded
    print(prompt, option)
    ckpt, step = opts[option]
    progress((0, step))
    # ZeroGPU allocates the device per call, so move the models lazily.
    if pipe.device.type != device:
        pipe.to(device, dtype)
    if safety_checker.device.type != device:
        safety_checker.to(device, dtype)
    # Swap in the requested checkpoint if it differs from the loaded one.
    if step != step_loaded:
        print(f"Switching checkpoint from {step_loaded} to {step} steps.")
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config,
            timestep_spacing="trailing",
            prediction_type="sample" if step == 1 else "epsilon",  # 1-step model predicts x0.
        )
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
        step_loaded = step
    # Report per-step progress to the Gradio progress bar.
    def inference_callback(p, i, t, kwargs):
        progress((i + 1, step))
        return kwargs
    results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")
    # Safety check. The pipeline already returns images in [0, 1] for
    # output_type="pt", so skip the extra denormalization when converting
    # to PIL for the CLIP feature extractor.
    feature_extractor_input = image_processor.postprocess(
        results.images, output_type="pil", do_denormalize=[False] * len(results.images)
    )
    safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
    images, has_nsfw_concept = safety_checker(
        images=results.images, clip_input=safety_checker_input.pixel_values.to(device, dtype)
    )
    if has_nsfw_concept[0]:
        gr.Warning("Safety checker triggered. Image may contain violent or sexual content.")
        print(f"Safety checker triggered on prompt: {prompt}")
    # The checker blanks flagged images in place; convert the tensor to PIL for Gradio.
    return Image.fromarray((images[0] * 255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy())
with gr.Blocks(css="style.css") as demo:
    gr.HTML(
        "<h1><center>SDXL-Lightning</center></h1>" +
        "<p><center>Lightning-fast text-to-image generation</center></p>" +
        "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
    )
    with gr.Row():
        prompt = gr.Textbox(
            label="Text prompt",
            scale=8,
        )
        option = gr.Dropdown(
            label="Inference steps",
            choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
            value="4 Steps",
            interactive=True,
        )
        submit = gr.Button(
            scale=1,
            variant="primary",
        )
    img = gr.Image(label="SDXL-Lightning Generated Image")
    prompt.submit(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
    submit.click(
        fn=generate,
        inputs=[prompt, option],
        outputs=img,
    )
    gr.Examples(
        fn=generate,
        examples=[
            ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
            ["A lion in the galaxy, octane render", "2 Steps"],
            ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
            ["A girl smiling", "4 Steps"],
            ["An astronaut riding a horse", "4 Steps"],
            ["A fish on a bicycle, colorful art", "4 Steps"],
            ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
            ["Rabbit portrait in a forest, fantasy", "4 Steps"],
            ["A panda swimming", "4 Steps"],
            ["Man portrait, ethereal", "8 Steps"],
        ],
        inputs=[prompt, option],
        outputs=img,
        cache_examples=False,
    )
    gr.HTML(
        "<p><small><center>This demo is built by the community</center></small></p>"
    )

demo.queue().launch()
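
# For reference, a minimal standalone sketch without Gradio/Spaces, following
# the usage shown on the SDXL-Lightning model card (4-step checkpoint; `base`
# and `repo` as defined above):
#
#   unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
#   unet.load_state_dict(load_file(hf_hub_download(repo, "sdxl_lightning_4step_unet.safetensors"), device="cuda"))
#   pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
#   pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
#   pipe("An astronaut riding a horse", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")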