kadirnar committed on
Commit
35dc227
1 Parent(s): 74d3344

Create stable_cascade.py

Browse files
Files changed (1) hide show
  1. stable_cascade.py +137 -0
stable_cascade.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
3
+ import gradio as gr
4
+
5
+
6
# Load both Stable Cascade stages once, at import time, and move them to the GPU.
# The prior is loaded in bfloat16 and the decoder in float16.
prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16,
)
prior = prior.to("cuda")
prior.enable_xformers_memory_efficient_attention()

decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16,
)
decoder = decoder.to("cuda")
decoder.enable_xformers_memory_efficient_attention()
12
+
13
def generate_images(
    prompt="a photo of a girl",
    negative_prompt="bad,ugly,deformed",
    height=1024,
    width=1024,
    guidance_scale=4.0,
    prior_inference_steps=20,
    decoder_inference_steps=10,
):
    """
    Generate images for a text prompt with the two-stage Stable Cascade pipeline.

    The module-level ``prior`` pipeline turns the prompt into image embeddings;
    the module-level ``decoder`` pipeline turns those embeddings into images.

    Parameters:
    - prompt (str): The prompt to generate images for.
    - negative_prompt (str): The negative prompt to guide image generation away from.
    - height (int): The height of the generated images.
    - width (int): The width of the generated images.
    - guidance_scale (float): The guidance scale passed to the prior model.
    - prior_inference_steps (int): The number of inference steps for the prior model.
    - decoder_inference_steps (int): The number of inference steps for the decoder model.

    Returns:
    - The ``images`` attribute of the decoder output, requested with
      ``output_type="pil"`` (one image per prompt is requested from the prior).
    """
    # Callers such as UI sliders may hand over whole numbers as floats; sizes
    # and step counts are documented as ints, so normalise them up front.
    height = int(height)
    width = int(width)
    prior_inference_steps = int(prior_inference_steps)
    decoder_inference_steps = int(decoder_inference_steps)

    # Stage 1: generate image embeddings using the prior model.
    prior_output = prior(
        prompt=prompt,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        num_inference_steps=prior_inference_steps,
    )

    # The prior and decoder are loaded with different dtypes, so cast the
    # embeddings to whatever dtype the decoder actually uses rather than
    # hard-coding float16.
    image_embeddings = prior_output.image_embeddings.to(decoder.dtype)

    # Stage 2: generate images using the decoder model and the embeddings
    # from the prior model.
    decoder_output = decoder(
        image_embeddings=image_embeddings,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=0.0,  # Guidance scale typically set to 0 for decoder as guidance is applied in the prior
        output_type="pil",
        num_inference_steps=decoder_inference_steps,
    )

    return decoder_output.images
58
+
59
+
60
def web_demo():
    """
    Build the text-to-image demo UI inside a ``gr.Blocks`` context.

    Creates two prompt textboxes, sliders for image size, guidance scale and
    the prior/decoder step counts, a "Generate Image" button and an output
    gallery. Clicking the button calls ``generate_images`` with the control
    values in the same order as that function's parameters and shows the
    result in the gallery.

    Nothing is returned and the UI is not launched here.
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a flattened listing -- confirm against the intended
    # layout.
    with gr.Blocks():
        with gr.Row():
            # Left column: all input controls and the trigger button.
            with gr.Column():
                prompt_box = gr.Textbox(
                    lines=1, placeholder="Prompt", show_label=False
                )
                negative_prompt_box = gr.Textbox(
                    lines=1, placeholder="Negative Prompt", show_label=False
                )

                with gr.Row(), gr.Column():
                    height_slider = gr.Slider(
                        label="Image Height",
                        minimum=128,
                        maximum=1280,
                        step=32,
                        value=512,
                    )
                    width_slider = gr.Slider(
                        label="Image Width",
                        minimum=128,
                        maximum=1280,
                        step=32,
                        value=512,
                    )

                with gr.Row(), gr.Column():
                    guidance_slider = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.1,
                        maximum=15,
                        step=0.1,
                        value=4.0,
                    )
                    prior_steps_slider = gr.Slider(
                        label="Prior Inference Step",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=20,
                    )
                    decoder_steps_slider = gr.Slider(
                        label="Decoder Inference Step",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=10,
                    )

                generate_button = gr.Button(value="Generate Image")

            # Right column: the generated images.
            with gr.Column():
                gallery = gr.Gallery(
                    label="Generated images",
                    show_label=False,
                    elem_id="gallery",
                )

        # Order must match the positional parameters of generate_images.
        controls = [
            prompt_box,
            negative_prompt_box,
            height_slider,
            width_slider,
            guidance_slider,
            prior_steps_slider,
            decoder_steps_slider,
        ]
        generate_button.click(fn=generate_images, inputs=controls, outputs=gallery)