PeterL1n commited on
Commit
7c672bb
·
verified ·
1 Parent(s): 83eb2ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -19,22 +19,24 @@ opts = {
19
  "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
20
  }
21
 
22
- step_loaded = None
23
- # unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
24
- # unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]), device=device))
25
- pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=dtype, variant="fp16").to(device)
26
 
27
  @spaces.GPU(enable_queue=True)
28
  def generate_image(prompt, option):
29
  global step_loaded
 
30
  ckpt, step = opts[option]
31
  if step != step_loaded:
 
32
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
33
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
34
  step_loaded = step
35
  return pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
36
 
37
- with gr.Blocks() as demo:
38
  gr.HTML(
39
  "<h1><center>SDXL-Lightning</center></h1>" +
40
  "<p><center>Lightning-fast text-to-image generation</center></p>" +
@@ -57,7 +59,7 @@ with gr.Blocks() as demo:
57
  variant="primary"
58
  )
59
 
60
- img = gr.Image(label="SDXL-Lightening Generated Image")
61
 
62
  prompt.submit(
63
  fn=generate_image,
 
19
  "8 Steps" : ("sdxl_lightning_8step_unet.safetensors", 8),
20
  }
21
 
22
+ step_loaded = 4
23
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
24
+ unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]), device=device))
25
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
26
 
27
  @spaces.GPU(enable_queue=True)
28
  def generate_image(prompt, option):
29
  global step_loaded
30
+ print(prompt, option)
31
  ckpt, step = opts[option]
32
  if step != step_loaded:
33
+ print(f"Switching checkpoint from {step_loaded} to {step}")
34
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
35
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
36
  step_loaded = step
37
  return pipe(prompt, num_inference_steps=step, guidance_scale=0).images[0]
38
 
39
+ with gr.Blocks(css="style.css") as demo:
40
  gr.HTML(
41
  "<h1><center>SDXL-Lightning</center></h1>" +
42
  "<p><center>Lightning-fast text-to-image generation</center></p>" +
 
59
  variant="primary"
60
  )
61
 
62
+ img = gr.Image(label="SDXL-Lighting Generated Image")
63
 
64
  prompt.submit(
65
  fn=generate_image,