kadirnar committed
Commit 80eae63
1 Parent(s): 7e65a14

Update stable_cascade.py

Files changed (1)
  1. stable_cascade.py +17 -7
stable_cascade.py CHANGED
@@ -7,14 +7,14 @@ import gradio as gr
 prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16)
 decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.bfloat16)
 
-prior_pipeline.enable_model_cpu_offload()
-decoder_pipeline.enable_model_cpu_offload()
+prior.enable_model_cpu_offload()
+decoder.enable_model_cpu_offload()
 
-prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
-decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
+prior.prior = torch.compile(prior.prior, mode="reduce-overhead", fullgraph=True)
+decoder.decoder = torch.compile(decoder.decoder, mode="max-autotune", fullgraph=True)
 
-prior_pipeline.to("cuda")
-decoder_pipeline.to("cuda")
+prior.to("cuda")
+decoder.to("cuda")
 
 def generate_images(
     prompt="a photo of a girl",
@@ -22,6 +22,7 @@ def generate_images(
     height=1024,
     width=1024,
     guidance_scale=4.0,
+    num_images_per_prompt=1,
     prior_inference_steps=20,
     decoder_inference_steps=10
 ):
@@ -46,7 +47,7 @@ def generate_images(
         width=width,
         negative_prompt=negative_prompt,
         guidance_scale=guidance_scale,
-        num_images_per_prompt=1,
+        num_images_per_prompt=num_images_per_prompt,
         num_inference_steps=prior_inference_steps
     )
 
@@ -80,6 +81,14 @@ def web_demo():
                 )
                 with gr.Row():
                     with gr.Column():
+                        text2image_num_images_per_prompt = gr.Slider(
+                            minimum=1,
+                            maximum=4,
+                            step=1,
+                            value=1,
+                            label="Number of Images",
+                        )
+
                         text2image_height = gr.Slider(
                             minimum=128,
                             maximum=1280,
@@ -136,6 +145,7 @@ def web_demo():
                 text2image_height,
                 text2image_width,
                 text2image_guidance_scale,
+                text2image_num_images_per_prompt,
                 text2image_prior_inference_step,
                 text2image_decoder_inference_step
             ],
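
The hunks above only show the pipeline setup, the generate_images signature, the prior call, and the Gradio wiring; the decoder stage is outside this diff. For context, here is a minimal sketch of how the two Stable Cascade stages typically chain together with the new num_images_per_prompt argument. It follows the standard diffusers usage, not the exact committed function body, and the parameter values simply mirror the defaults in the diff:

# Minimal sketch, assuming the standard diffusers Stable Cascade API; this is
# not the committed generate_images body (the decoder call is outside the diff).
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16
).to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "a photo of a girl"
negative_prompt = ""

# Stage 1: the prior produces one image embedding per requested image.
prior_output = prior(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=1024,
    width=1024,
    guidance_scale=4.0,
    num_images_per_prompt=2,  # the new argument exposed by the Gradio slider
    num_inference_steps=20,
)

# Stage 2: the decoder turns the embeddings into images; its batch size follows
# the embeddings, so it does not need its own num_images_per_prompt here.
images = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=0.0,
    num_inference_steps=10,
    output_type="pil",
).images

Because the prior emits one embedding per requested image, exposing num_images_per_prompt on the prior call (as this commit does) is enough to control how many images the demo returns.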