Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ import torch
|
|
10 |
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
|
11 |
from typing import Tuple
|
12 |
|
|
|
13 |
bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
|
14 |
bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
|
15 |
default_negative = os.getenv("default_negative", "")
|
@@ -23,7 +24,7 @@ def check_text(prompt, negative=""):
|
|
23 |
return True
|
24 |
return False
|
25 |
|
26 |
-
# Quality/Style
|
27 |
style_list = [
|
28 |
{
|
29 |
"name": "3840 x 2160",
|
@@ -47,7 +48,7 @@ style_list = [
|
|
47 |
},
|
48 |
]
|
49 |
|
50 |
-
# Collage styles
|
51 |
collage_style_list = [
|
52 |
{
|
53 |
"name": "Hi-Res",
|
@@ -176,7 +177,7 @@ collage_style_list = [
|
|
176 |
},
|
177 |
]
|
178 |
|
179 |
-
# Filters
|
180 |
filters = {
|
181 |
"Vivid": {
|
182 |
"prompt": "extra vivid {prompt}",
|
@@ -243,10 +244,10 @@ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
|
243 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
244 |
|
245 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
246 |
-
# Set dtype based on device: half
|
247 |
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
248 |
|
249 |
-
# Load
|
250 |
if torch.cuda.is_available():
|
251 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
252 |
"SG161222/RealVisXL_V5.0_Lightning",
|
@@ -254,27 +255,60 @@ if torch.cuda.is_available():
|
|
254 |
use_safetensors=True,
|
255 |
add_watermarker=False
|
256 |
).to(device)
|
257 |
-
# Ensure
|
258 |
pipe.text_encoder = pipe.text_encoder.half()
|
259 |
|
260 |
if ENABLE_CPU_OFFLOAD:
|
261 |
pipe.enable_model_cpu_offload()
|
262 |
else:
|
263 |
pipe.to(device)
|
264 |
-
print("Loaded on Device!")
|
265 |
|
266 |
if USE_TORCH_COMPILE:
|
267 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
268 |
-
print("Model Compiled!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
else:
|
270 |
-
# On CPU
|
271 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
272 |
"SG161222/RealVisXL_V5.0_Lightning",
|
273 |
torch_dtype=dtype,
|
274 |
use_safetensors=True,
|
275 |
add_watermarker=False
|
276 |
).to(device)
|
277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
|
279 |
def save_image(img, path):
|
280 |
img.save(path)
|
@@ -298,6 +332,7 @@ def generate(
|
|
298 |
height: int = 1024,
|
299 |
guidance_scale: float = 3,
|
300 |
randomize_seed: bool = False,
|
|
|
301 |
use_resolution_binning: bool = True,
|
302 |
progress=gr.Progress(track_tqdm=True),
|
303 |
):
|
@@ -317,7 +352,7 @@ def generate(
|
|
317 |
if not use_negative_prompt:
|
318 |
negative_prompt = ""
|
319 |
negative_prompt += default_negative
|
320 |
-
|
321 |
grid_sizes = {
|
322 |
"2x1": (2, 1),
|
323 |
"1x2": (1, 2),
|
@@ -343,11 +378,14 @@ def generate(
|
|
343 |
"output_type": "pil",
|
344 |
}
|
345 |
|
346 |
-
|
347 |
-
|
|
|
|
|
|
|
|
|
348 |
|
349 |
grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
|
350 |
-
|
351 |
for i, img in enumerate(images[:num_images]):
|
352 |
grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
|
353 |
|
@@ -385,15 +423,20 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
385 |
placeholder="Enter your prompt",
|
386 |
container=False,
|
387 |
)
|
388 |
-
run_button = gr.Button("Generate
|
389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
with gr.Row(visible=True):
|
391 |
grid_size_selection = gr.Dropdown(
|
392 |
choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
|
393 |
value="1x1",
|
394 |
label="Grid Size"
|
395 |
)
|
396 |
-
|
397 |
with gr.Row(visible=True):
|
398 |
filter_selection = gr.Dropdown(
|
399 |
show_label=True,
|
@@ -403,7 +446,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
403 |
value=DEFAULT_FILTER_NAME,
|
404 |
label="Filter Type",
|
405 |
)
|
406 |
-
|
407 |
with gr.Row(visible=True):
|
408 |
collage_style_selection = gr.Dropdown(
|
409 |
show_label=True,
|
@@ -413,7 +455,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
413 |
value=DEFAULT_COLLAGE_STYLE_NAME,
|
414 |
label="Collage Template + Duotone Canvas",
|
415 |
)
|
416 |
-
|
417 |
with gr.Row(visible=True):
|
418 |
style_selection = gr.Dropdown(
|
419 |
show_label=True,
|
@@ -423,7 +464,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
423 |
value=DEFAULT_STYLE_NAME,
|
424 |
label="Quality Style",
|
425 |
)
|
426 |
-
|
427 |
with gr.Accordion("Advanced options", open=False):
|
428 |
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
|
429 |
negative_prompt = gr.Text(
|
@@ -458,7 +498,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
458 |
visible=True
|
459 |
)
|
460 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
461 |
-
|
462 |
with gr.Row(visible=True):
|
463 |
width = gr.Slider(
|
464 |
label="Width",
|
@@ -474,7 +513,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
474 |
step=64,
|
475 |
value=1024,
|
476 |
)
|
477 |
-
|
478 |
with gr.Row():
|
479 |
guidance_scale = gr.Slider(
|
480 |
label="Guidance Scale",
|
@@ -483,10 +521,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
483 |
step=0.1,
|
484 |
value=6,
|
485 |
)
|
486 |
-
|
487 |
with gr.Column(scale=2):
|
488 |
result = gr.Gallery(label="Result", columns=1, show_label=False)
|
489 |
-
|
490 |
gr.Examples(
|
491 |
examples=examples,
|
492 |
inputs=prompt,
|
@@ -494,14 +530,12 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
494 |
fn=generate,
|
495 |
cache_examples=CACHE_EXAMPLES,
|
496 |
)
|
497 |
-
|
498 |
use_negative_prompt.change(
|
499 |
fn=lambda x: gr.update(visible=x),
|
500 |
inputs=use_negative_prompt,
|
501 |
outputs=negative_prompt,
|
502 |
api_name=False,
|
503 |
)
|
504 |
-
|
505 |
gr.on(
|
506 |
triggers=[
|
507 |
prompt.submit,
|
@@ -522,6 +556,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
|
|
522 |
height,
|
523 |
guidance_scale,
|
524 |
randomize_seed,
|
|
|
525 |
],
|
526 |
outputs=[result, seed],
|
527 |
api_name="run",
|
|
|
10 |
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
|
11 |
from typing import Tuple
|
12 |
|
13 |
+
# Load restricted words
|
14 |
bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
|
15 |
bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
|
16 |
default_negative = os.getenv("default_negative", "")
|
|
|
24 |
return True
|
25 |
return False
|
26 |
|
27 |
+
# Quality/Style--------------------------------------------------------------------
|
28 |
style_list = [
|
29 |
{
|
30 |
"name": "3840 x 2160",
|
|
|
48 |
},
|
49 |
]
|
50 |
|
51 |
+
# Collage styles--------------------------------------------------------------------
|
52 |
collage_style_list = [
|
53 |
{
|
54 |
"name": "Hi-Res",
|
|
|
177 |
},
|
178 |
]
|
179 |
|
180 |
+
# Filters--------------------------------------------------------------------
|
181 |
filters = {
|
182 |
"Vivid": {
|
183 |
"prompt": "extra vivid {prompt}",
|
|
|
244 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
245 |
|
246 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
247 |
+
# Set dtype based on device: half for CUDA, float32 for CPU
|
248 |
dtype = torch.float16 if device.type == "cuda" else torch.float32
|
249 |
|
250 |
+
# Load primary model (RealVisXL_V5.0_Lightning)
|
251 |
if torch.cuda.is_available():
|
252 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
253 |
"SG161222/RealVisXL_V5.0_Lightning",
|
|
|
255 |
use_safetensors=True,
|
256 |
add_watermarker=False
|
257 |
).to(device)
|
258 |
+
# Ensure text encoder uses half precision on GPU
|
259 |
pipe.text_encoder = pipe.text_encoder.half()
|
260 |
|
261 |
if ENABLE_CPU_OFFLOAD:
|
262 |
pipe.enable_model_cpu_offload()
|
263 |
else:
|
264 |
pipe.to(device)
|
265 |
+
print("Loaded RealVisXL_V5.0_Lightning on Device!")
|
266 |
|
267 |
if USE_TORCH_COMPILE:
|
268 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
269 |
+
print("Model RealVisXL_V5.0_Lightning Compiled!")
|
270 |
+
|
271 |
+
# Load new model (RealVisXL_V4.0)
|
272 |
+
pipe2 = StableDiffusionXLPipeline.from_pretrained(
|
273 |
+
"SG161222/RealVisXL_V4.0",
|
274 |
+
torch_dtype=dtype,
|
275 |
+
use_safetensors=True,
|
276 |
+
add_watermarker=False,
|
277 |
+
).to(device)
|
278 |
+
pipe2.text_encoder = pipe2.text_encoder.half()
|
279 |
+
|
280 |
+
if ENABLE_CPU_OFFLOAD:
|
281 |
+
pipe2.enable_model_cpu_offload()
|
282 |
+
else:
|
283 |
+
pipe2.to(device)
|
284 |
+
print("Loaded RealVisXL_V4.0 on Device!")
|
285 |
+
|
286 |
+
if USE_TORCH_COMPILE:
|
287 |
+
pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
|
288 |
+
print("Model RealVisXL_V4.0 Compiled!")
|
289 |
else:
|
290 |
+
# On CPU load both models in float32
|
291 |
pipe = StableDiffusionXLPipeline.from_pretrained(
|
292 |
"SG161222/RealVisXL_V5.0_Lightning",
|
293 |
torch_dtype=dtype,
|
294 |
use_safetensors=True,
|
295 |
add_watermarker=False
|
296 |
).to(device)
|
297 |
+
pipe2 = StableDiffusionXLPipeline.from_pretrained(
|
298 |
+
"SG161222/RealVisXL_V4.0",
|
299 |
+
torch_dtype=dtype,
|
300 |
+
use_safetensors=True,
|
301 |
+
add_watermarker=False,
|
302 |
+
).to(device)
|
303 |
+
print("Model Loaded !!")
|
304 |
+
|
305 |
+
# A dictionary to easily choose the model based on selection.
|
306 |
+
DEFAULT_MODEL = "RealVisXL_V5.0_Lightning"
|
307 |
+
MODEL_CHOICES = [DEFAULT_MODEL, "RealVisXL_V4.0"]
|
308 |
+
models = {
|
309 |
+
"RealVisXL_V5.0_Lightning": pipe,
|
310 |
+
"RealVisXL_V4.0": pipe2
|
311 |
+
}
|
312 |
|
313 |
def save_image(img, path):
|
314 |
img.save(path)
|
|
|
332 |
height: int = 1024,
|
333 |
guidance_scale: float = 3,
|
334 |
randomize_seed: bool = False,
|
335 |
+
model_choice: str = DEFAULT_MODEL,
|
336 |
use_resolution_binning: bool = True,
|
337 |
progress=gr.Progress(track_tqdm=True),
|
338 |
):
|
|
|
352 |
if not use_negative_prompt:
|
353 |
negative_prompt = ""
|
354 |
negative_prompt += default_negative
|
355 |
+
|
356 |
grid_sizes = {
|
357 |
"2x1": (2, 1),
|
358 |
"1x2": (1, 2),
|
|
|
378 |
"output_type": "pil",
|
379 |
}
|
380 |
|
381 |
+
if device.type == "cuda":
|
382 |
+
torch.cuda.empty_cache()
|
383 |
+
|
384 |
+
# Choose pipeline based on user selection
|
385 |
+
selected_pipe = models.get(model_choice, pipe)
|
386 |
+
images = selected_pipe(**options).images
|
387 |
|
388 |
grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
|
|
|
389 |
for i, img in enumerate(images[:num_images]):
|
390 |
grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
|
391 |
|
|
|
423 |
placeholder="Enter your prompt",
|
424 |
container=False,
|
425 |
)
|
426 |
+
run_button = gr.Button("Generate as (1024 x 1024)🍺", scale=0, elem_classes="submit-btn")
|
427 |
+
|
428 |
+
with gr.Row(visible=True):
|
429 |
+
model_selection = gr.Dropdown(
|
430 |
+
choices=MODEL_CHOICES,
|
431 |
+
value=DEFAULT_MODEL,
|
432 |
+
label="Model Selection",
|
433 |
+
)
|
434 |
with gr.Row(visible=True):
|
435 |
grid_size_selection = gr.Dropdown(
|
436 |
choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
|
437 |
value="1x1",
|
438 |
label="Grid Size"
|
439 |
)
|
|
|
440 |
with gr.Row(visible=True):
|
441 |
filter_selection = gr.Dropdown(
|
442 |
show_label=True,
|
|
|
446 |
value=DEFAULT_FILTER_NAME,
|
447 |
label="Filter Type",
|
448 |
)
|
|
|
449 |
with gr.Row(visible=True):
|
450 |
collage_style_selection = gr.Dropdown(
|
451 |
show_label=True,
|
|
|
455 |
value=DEFAULT_COLLAGE_STYLE_NAME,
|
456 |
label="Collage Template + Duotone Canvas",
|
457 |
)
|
|
|
458 |
with gr.Row(visible=True):
|
459 |
style_selection = gr.Dropdown(
|
460 |
show_label=True,
|
|
|
464 |
value=DEFAULT_STYLE_NAME,
|
465 |
label="Quality Style",
|
466 |
)
|
|
|
467 |
with gr.Accordion("Advanced options", open=False):
|
468 |
use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
|
469 |
negative_prompt = gr.Text(
|
|
|
498 |
visible=True
|
499 |
)
|
500 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
|
|
501 |
with gr.Row(visible=True):
|
502 |
width = gr.Slider(
|
503 |
label="Width",
|
|
|
513 |
step=64,
|
514 |
value=1024,
|
515 |
)
|
|
|
516 |
with gr.Row():
|
517 |
guidance_scale = gr.Slider(
|
518 |
label="Guidance Scale",
|
|
|
521 |
step=0.1,
|
522 |
value=6,
|
523 |
)
|
|
|
524 |
with gr.Column(scale=2):
|
525 |
result = gr.Gallery(label="Result", columns=1, show_label=False)
|
|
|
526 |
gr.Examples(
|
527 |
examples=examples,
|
528 |
inputs=prompt,
|
|
|
530 |
fn=generate,
|
531 |
cache_examples=CACHE_EXAMPLES,
|
532 |
)
|
|
|
533 |
use_negative_prompt.change(
|
534 |
fn=lambda x: gr.update(visible=x),
|
535 |
inputs=use_negative_prompt,
|
536 |
outputs=negative_prompt,
|
537 |
api_name=False,
|
538 |
)
|
|
|
539 |
gr.on(
|
540 |
triggers=[
|
541 |
prompt.submit,
|
|
|
556 |
height,
|
557 |
guidance_scale,
|
558 |
randomize_seed,
|
559 |
+
model_selection,
|
560 |
],
|
561 |
outputs=[result, seed],
|
562 |
api_name="run",
|