PeterL1n commited on
Commit
8425644
·
verified ·
1 Parent(s): f1e3c7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import spaces
4
- from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
5
  from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
7
 
@@ -9,26 +9,25 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
9
  base = "stabilityai/stable-diffusion-xl-base-1.0"
10
  repo = "ByteDance/SDXL-Lightning"
11
  opts = {
12
- "1 Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
13
- "2 Steps" : ["sdxl_lightning_2step_unet.safetensors", 2],
14
- "4 Steps" : ["sdxl_lightning_4step_unet.safetensors", 4],
15
- "8 Steps" : ["sdxl_lightning_8step_unet.safetensors", 8],
16
  }
17
 
 
 
 
18
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
19
- last_step = None
20
 
21
- # Function
22
  @spaces.GPU(enable_queue=True)
23
  def generate_image(prompt, option):
24
  ckpt, step = opts[option]
25
- if last_step != step:
26
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
27
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
28
- last_step = step
29
- image = pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
30
- return image
31
-
32
 
33
  with gr.Blocks() as demo:
34
  gr.HTML("<h1><center>SDXL-Lightning</center></h1>")
 
1
  import gradio as gr
2
  import torch
3
  import spaces
4
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
5
  from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
7
 
 
9
  base = "stabilityai/stable-diffusion-xl-base-1.0"
10
  repo = "ByteDance/SDXL-Lightning"
11
  opts = {
12
+ "1 Step" : ("sdxl_lightning_1step_unet_x0.safetensors", 1),
13
+ "2 Steps" : ("sdxl_lightning_2step_unet.safetensors", 2),
14
+ "4 Steps" : ("sdxl_lightning_4step_unet.safetensors", 4),
15
+ "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
16
  }
17
 
18
+ step_loaded = 4
19
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
20
+ unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]), device=device))
21
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to(device)
 
22
 
 
23
  @spaces.GPU(enable_queue=True)
24
  def generate_image(prompt, option):
25
  ckpt, step = opts[option]
26
+ if step_loaded != step:
27
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
28
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
29
+ step_loaded = step
30
+ return pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
 
 
31
 
32
  with gr.Blocks() as demo:
33
  gr.HTML("<h1><center>SDXL-Lightning</center></h1>")