fia commited on
Commit
64d8b6a
1 Parent(s): d5bc43e

CPU Settings

Browse files

Add Width and Height

Fix prompt

Streamlit secret API key

Files changed (2) hide show
  1. app.py +31 -159
  2. unsafe.png +0 -0
app.py CHANGED
@@ -6,41 +6,34 @@ from diffusers import StableDiffusionPipeline
6
  from datasets import load_dataset
7
  from PIL import Image
8
  import re
 
9
 
10
  model_id = "CompVis/stable-diffusion-v1-4"
11
- device = "cuda"
12
 
13
  #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
14
- pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
15
- pipe = pipe.to(device)
16
- #When running locally, you won`t have access to this, so you can remove this part
17
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
18
- word_list = word_list_dataset["train"]['text']
19
 
20
- def infer(prompt, samples, steps, scale, seed):
21
- #When running locally you can also remove this filter
22
- for filter in word_list:
23
- if re.search(rf"\b{filter}\b", prompt):
24
- raise gr.Error("Unsafe content found. Please try again with different prompts.")
25
-
26
- generator = torch.Generator(device=device).manual_seed(seed)
27
-
28
- #If you are running locally with CPU, you can remove the `with autocast("cuda")`
29
- with autocast("cuda"):
30
  images_list = pipe(
31
- [prompt] * samples,
 
 
32
  num_inference_steps=steps,
33
  guidance_scale=scale,
34
- generator=generator,
35
- )
36
- images = []
37
- safe_image = Image.open(r"unsafe.png")
38
- for i, image in enumerate(images_list["sample"]):
39
- if(images_list["nsfw_content_detected"][i]):
40
- images.append(safe_image)
41
- else:
42
- images.append(image)
43
- return images
44
 
45
  css = """
46
  .gradio-container {
@@ -89,18 +82,6 @@ css = """
89
  --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
90
  --tw-ring-opacity: .5;
91
  }
92
- #advanced-btn {
93
- font-size: .7rem !important;
94
- line-height: 19px;
95
- margin-top: 12px;
96
- margin-bottom: 12px;
97
- padding: 2px 8px;
98
- border-radius: 14px !important;
99
- }
100
- #advanced-options {
101
- display: none;
102
- margin-bottom: 20px;
103
- }
104
  .footer {
105
  margin-bottom: 45px;
106
  margin-top: 35px;
@@ -129,44 +110,6 @@ css = """
129
 
130
  block = gr.Blocks(css=css)
131
 
132
- examples = [
133
- [
134
- 'A high tech solarpunk utopia in the Amazon rainforest',
135
- 4,
136
- 45,
137
- 7.5,
138
- 1024,
139
- ],
140
- [
141
- 'A pikachu fine dining with a view to the Eiffel Tower',
142
- 4,
143
- 45,
144
- 7,
145
- 1024,
146
- ],
147
- [
148
- 'A mecha robot in a favela in expressionist style',
149
- 4,
150
- 45,
151
- 7,
152
- 1024,
153
- ],
154
- [
155
- 'an insect robot preparing a delicious meal',
156
- 4,
157
- 45,
158
- 7,
159
- 1024,
160
- ],
161
- [
162
- "A small cabin on top of a snowy mountain in the style of Disney, artstation",
163
- 4,
164
- 45,
165
- 7,
166
- 1024,
167
- ],
168
- ]
169
-
170
  with block:
171
  gr.HTML(
172
  """
@@ -179,54 +122,10 @@ with block:
179
  font-size: 1.75rem;
180
  "
181
  >
182
- <svg
183
- width="0.65em"
184
- height="0.65em"
185
- viewBox="0 0 115 115"
186
- fill="none"
187
- xmlns="http://www.w3.org/2000/svg"
188
- >
189
- <rect width="23" height="23" fill="white"></rect>
190
- <rect y="69" width="23" height="23" fill="white"></rect>
191
- <rect x="23" width="23" height="23" fill="#AEAEAE"></rect>
192
- <rect x="23" y="69" width="23" height="23" fill="#AEAEAE"></rect>
193
- <rect x="46" width="23" height="23" fill="white"></rect>
194
- <rect x="46" y="69" width="23" height="23" fill="white"></rect>
195
- <rect x="69" width="23" height="23" fill="black"></rect>
196
- <rect x="69" y="69" width="23" height="23" fill="black"></rect>
197
- <rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
198
- <rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
199
- <rect x="115" y="46" width="23" height="23" fill="white"></rect>
200
- <rect x="115" y="115" width="23" height="23" fill="white"></rect>
201
- <rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
202
- <rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
203
- <rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
204
- <rect x="92" y="69" width="23" height="23" fill="white"></rect>
205
- <rect x="69" y="46" width="23" height="23" fill="white"></rect>
206
- <rect x="69" y="115" width="23" height="23" fill="white"></rect>
207
- <rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
208
- <rect x="46" y="46" width="23" height="23" fill="black"></rect>
209
- <rect x="46" y="115" width="23" height="23" fill="black"></rect>
210
- <rect x="46" y="69" width="23" height="23" fill="black"></rect>
211
- <rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
212
- <rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
213
- <rect x="23" y="69" width="23" height="23" fill="black"></rect>
214
- </svg>
215
  <h1 style="font-weight: 900; margin-bottom: 7px;">
216
- Stable Diffusion Demo
217
  </h1>
218
  </div>
219
- <p style="margin-bottom: 10px; font-size: 94%">
220
- Stable Diffusion is a state of the art text-to-image model that generates
221
- images from text.<br>For faster generation and forthcoming API
222
- access you can try
223
- <a
224
- href="http://beta.dreamstudio.ai/"
225
- style="text-decoration: underline;"
226
- target="_blank"
227
- >DreamStudio Beta</a
228
- >
229
- </p>
230
  </div>
231
  """
232
  )
@@ -252,51 +151,24 @@ with block:
252
  label="Generated images", show_label=False, elem_id="gallery"
253
  ).style(grid=[2], height="auto")
254
 
255
- advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
 
 
256
 
257
- with gr.Row(elem_id="advanced-options"):
258
- samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
259
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
260
  scale = gr.Slider(
261
  label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
262
  )
263
  seed = gr.Slider(
264
  label="Seed",
265
- minimum=0,
266
  maximum=2147483647,
267
  step=1,
268
- randomize=True,
269
  )
270
-
271
- ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, steps, scale, seed], outputs=gallery, cache_examples=True)
272
- ex.dataset.headers = [""]
273
-
274
 
275
- text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
276
- btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
277
- advanced_button.click(
278
- None,
279
- [],
280
- text,
281
- _js="""
282
- () => {
283
- const options = document.querySelector("body > gradio-app").querySelector("#advanced-options");
284
- options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
285
- }""",
286
- )
287
- gr.HTML(
288
- """
289
- <div class="footer">
290
- <p>Model by <a href="https://huggingface.co/CompVis" style="text-decoration: underline;" target="_blank">CompVis</a> and <a href="https://huggingface.co/stabilityai" style="text-decoration: underline;" target="_blank">Stability AI</a> - Gradio Demo by 🤗 Hugging Face
291
- </p>
292
- </div>
293
- <div class="acknowledgments">
294
- <p><h4>LICENSE</h4>
295
- The model is licensed with a <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" style="text-decoration: underline;" target="_blank">CreativeML Open RAIL-M</a> license. The authors claim no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in this license. The license forbids you from sharing any content that violates any laws, produce any harm to a person, disseminate any personal information that would be meant for harm, spread misinformation and target vulnerable groups. For the full list of restrictions please <a href="https://huggingface.co/spaces/CompVis/stable-diffusion-license" target="_blank" style="text-decoration: underline;" target="_blank">read the license</a></p>
296
- <p><h4>Biases and content acknowledgment</h4>
297
- Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography and violence. The model was trained on the <a href="https://laion.ai/blog/laion-5b/" style="text-decoration: underline;" target="_blank">LAION-5B dataset</a>, which scraped non-curated image-text-pairs from the internet (the exception being the removal of illegal content) and is meant for research purposes. You can read more in the <a href="https://huggingface.co/CompVis/stable-diffusion-v1-4" style="text-decoration: underline;" target="_blank">model card</a></p>
298
- </div>
299
- """
300
- )
301
 
302
- block.queue(max_size=25).launch()
 
6
  from datasets import load_dataset
7
  from PIL import Image
8
  import re
9
+ import streamlit as st
10
 
11
  model_id = "CompVis/stable-diffusion-v1-4"
12
+ device = "cpu"
13
 
14
  #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
15
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=st.secrets["AUTH_KEY"], torch_dtype=torch.float32)
16
+ def dummy(images, **kwargs): return images, False
17
+ pipe.safety_checker = dummy
 
 
18
 
19
+ def infer(prompt, width, height, steps, scale, seed):
20
+ if seed == -1:
 
 
 
 
 
 
 
 
21
  images_list = pipe(
22
+ [prompt],
23
+ height=height,
24
+ width=width,
25
  num_inference_steps=steps,
26
  guidance_scale=scale,
27
+ generator=torch.Generator(device=device).manual_seed(seed))
28
+ else:
29
+ images_list = pipe(
30
+ [prompt],
31
+ height=height,
32
+ width=width,
33
+ num_inference_steps=steps,
34
+ guidance_scale=scale)
35
+
36
+ return images_list["sample"]
37
 
38
  css = """
39
  .gradio-container {
 
82
  --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
83
  --tw-ring-opacity: .5;
84
  }
 
 
 
 
 
 
 
 
 
 
 
 
85
  .footer {
86
  margin-bottom: 45px;
87
  margin-top: 35px;
 
110
 
111
  block = gr.Blocks(css=css)
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  with block:
114
  gr.HTML(
115
  """
 
122
  font-size: 1.75rem;
123
  "
124
  >
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  <h1 style="font-weight: 900; margin-bottom: 7px;">
126
+ Stable Diffusion CPU
127
  </h1>
128
  </div>
 
 
 
 
 
 
 
 
 
 
 
129
  </div>
130
  """
131
  )
 
151
  label="Generated images", show_label=False, elem_id="gallery"
152
  ).style(grid=[2], height="auto")
153
 
154
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
155
+ width = gr.Slider(label="Width", minimum=32, maximum=1024, value=512, step=8)
156
+ height = gr.Slider(label="Height", minimum=32, maximum=1024, value=512, step=8)
157
 
158
+ with gr.Row():
159
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=30, step=1)
 
160
  scale = gr.Slider(
161
  label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
162
  )
163
  seed = gr.Slider(
164
  label="Seed",
165
+ minimum=-1,
166
  maximum=2147483647,
167
  step=1,
168
+ value=-1,
169
  )
 
 
 
 
170
 
171
+ text.submit(infer, inputs=[text, width, height, steps, scale, seed], outputs=gallery)
172
+ btn.click(infer, inputs=[text, width, height, steps, scale, seed], outputs=gallery)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ block.queue(max_size=10).launch()
unsafe.png DELETED
Binary file (29.6 kB)