prithivMLmods committed
Commit 2a821e6 · verified · 1 Parent(s): 8e7a095

Update app.py: add RealVisXL_V4.0 as a second, user-selectable pipeline

Files changed (1)
  1. app.py +61 -26
app.py CHANGED
@@ -10,6 +10,7 @@ import torch
 from diffusers import DiffusionPipeline, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 from typing import Tuple
 
+# Load restricted words
 bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
 bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
 default_negative = os.getenv("default_negative", "")
@@ -23,7 +24,7 @@ def check_text(prompt, negative=""):
             return True
     return False
 
-# Quality/Style
+# Quality/Style--------------------------------------------------------------------
 style_list = [
     {
         "name": "3840 x 2160",
@@ -47,7 +48,7 @@ style_list = [
     },
 ]
 
-# Collage styles
+# Collage styles--------------------------------------------------------------------
 collage_style_list = [
     {
         "name": "Hi-Res",
@@ -176,7 +177,7 @@ collage_style_list = [
     },
 ]
 
-# Filters
+# Filters--------------------------------------------------------------------
 filters = {
     "Vivid": {
         "prompt": "extra vivid {prompt}",
@@ -243,10 +244,10 @@ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-# Set dtype based on device: half precision for CUDA, full precision for CPU
+# Set dtype based on device: half for CUDA, float32 for CPU
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# Load the pipeline
+# Load primary model (RealVisXL_V5.0_Lightning)
 if torch.cuda.is_available():
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "SG161222/RealVisXL_V5.0_Lightning",
@@ -254,27 +255,60 @@ if torch.cuda.is_available():
         use_safetensors=True,
         add_watermarker=False
     ).to(device)
-    # Ensure the text encoder is in half precision to match the rest of the model
+    # Ensure text encoder uses half precision on GPU
     pipe.text_encoder = pipe.text_encoder.half()
 
     if ENABLE_CPU_OFFLOAD:
         pipe.enable_model_cpu_offload()
     else:
         pipe.to(device)
-        print("Loaded on Device!")
+        print("Loaded RealVisXL_V5.0_Lightning on Device!")
 
     if USE_TORCH_COMPILE:
         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-        print("Model Compiled!")
+        print("Model RealVisXL_V5.0_Lightning Compiled!")
+
+    # Load new model (RealVisXL_V4.0)
+    pipe2 = StableDiffusionXLPipeline.from_pretrained(
+        "SG161222/RealVisXL_V4.0",
+        torch_dtype=dtype,
+        use_safetensors=True,
+        add_watermarker=False,
+    ).to(device)
+    pipe2.text_encoder = pipe2.text_encoder.half()
+
+    if ENABLE_CPU_OFFLOAD:
+        pipe2.enable_model_cpu_offload()
+    else:
+        pipe2.to(device)
+        print("Loaded RealVisXL_V4.0 on Device!")
+
+    if USE_TORCH_COMPILE:
+        pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
+        print("Model RealVisXL_V4.0 Compiled!")
 else:
-    # On CPU, load with float32
+    # On CPU load both models in float32
     pipe = StableDiffusionXLPipeline.from_pretrained(
         "SG161222/RealVisXL_V5.0_Lightning",
         torch_dtype=dtype,
         use_safetensors=True,
         add_watermarker=False
     ).to(device)
-    print("Running on CPU; model loaded in float32.")
+    pipe2 = StableDiffusionXLPipeline.from_pretrained(
+        "SG161222/RealVisXL_V4.0",
+        torch_dtype=dtype,
+        use_safetensors=True,
+        add_watermarker=False,
+    ).to(device)
+    print("Model Loaded !!")
+
+# A dictionary to easily choose the model based on selection.
+DEFAULT_MODEL = "RealVisXL_V5.0_Lightning"
+MODEL_CHOICES = [DEFAULT_MODEL, "RealVisXL_V4.0"]
+models = {
+    "RealVisXL_V5.0_Lightning": pipe,
+    "RealVisXL_V4.0": pipe2
+}
 
 def save_image(img, path):
     img.save(path)
@@ -298,6 +332,7 @@ def generate(
     height: int = 1024,
     guidance_scale: float = 3,
     randomize_seed: bool = False,
+    model_choice: str = DEFAULT_MODEL,
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True),
 ):
@@ -317,7 +352,7 @@ def generate(
     if not use_negative_prompt:
         negative_prompt = ""
     negative_prompt += default_negative
-    
+
     grid_sizes = {
         "2x1": (2, 1),
         "1x2": (1, 2),
@@ -343,11 +378,14 @@
         "output_type": "pil",
     }
 
-    torch.cuda.empty_cache() # Clear GPU memory if available
-    images = pipe(**options).images
+    if device.type == "cuda":
+        torch.cuda.empty_cache()
+
+    # Choose pipeline based on user selection
+    selected_pipe = models.get(model_choice, pipe)
+    images = selected_pipe(**options).images
 
     grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
-    
     for i, img in enumerate(images[:num_images]):
         grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))
 
@@ -385,15 +423,20 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 placeholder="Enter your prompt",
                 container=False,
             )
-            run_button = gr.Button("Generate as (1024 x 1024)🍺", scale=0, elem_classes="submit-btn")
-            
+            run_button = gr.Button("Generate as (1024 x 1024)🍺", scale=0, elem_classes="submit-btn")
+
+        with gr.Row(visible=True):
+            model_selection = gr.Dropdown(
+                choices=MODEL_CHOICES,
+                value=DEFAULT_MODEL,
+                label="Model Selection",
+            )
         with gr.Row(visible=True):
             grid_size_selection = gr.Dropdown(
                 choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
                 value="1x1",
                 label="Grid Size"
            )
-        
         with gr.Row(visible=True):
             filter_selection = gr.Dropdown(
                 show_label=True,
@@ -403,7 +446,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 value=DEFAULT_FILTER_NAME,
                 label="Filter Type",
             )
-        
         with gr.Row(visible=True):
             collage_style_selection = gr.Dropdown(
                 show_label=True,
@@ -413,7 +455,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 value=DEFAULT_COLLAGE_STYLE_NAME,
                 label="Collage Template + Duotone Canvas",
             )
-        
         with gr.Row(visible=True):
             style_selection = gr.Dropdown(
                 show_label=True,
@@ -423,7 +464,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                 value=DEFAULT_STYLE_NAME,
                 label="Quality Style",
             )
-        
        with gr.Accordion("Advanced options", open=False):
            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
            negative_prompt = gr.Text(
@@ -458,7 +498,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                visible=True
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-        
        with gr.Row(visible=True):
            width = gr.Slider(
                label="Width",
@@ -474,7 +513,6 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                step=64,
                value=1024,
            )
-        
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale",
@@ -483,10 +521,8 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
                step=0.1,
                value=6,
            )
-        
    with gr.Column(scale=2):
        result = gr.Gallery(label="Result", columns=1, show_label=False)
-    
    gr.Examples(
        examples=examples,
        inputs=prompt,
@@ -494,14 +530,12 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )
-    
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        api_name=False,
    )
-    
    gr.on(
        triggers=[
            prompt.submit,
@@ -522,6 +556,7 @@ with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
            height,
            guidance_scale,
            randomize_seed,
+            model_selection,
        ],
        outputs=[result, seed],
        api_name="run",
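
Note: the functional core of this diff is a pipeline registry. Both checkpoints are loaded once at startup, stored in the models dict under their display names, and generate() resolves the dropdown value with models.get(model_choice, pipe), falling back to the default pipeline on an unknown key. The sketch below is a minimal, runnable illustration of that flow; DummyPipe is a stand-in for StableDiffusionXLPipeline (so the example runs without downloading weights) and is not part of the app.

    class DummyPipe:
        """Stand-in for StableDiffusionXLPipeline: prints instead of rendering."""

        def __init__(self, name: str):
            self.name = name

        def __call__(self, **options):
            # The real pipelines return an object whose .images is a list of PIL images.
            print(f"generating with {self.name}")
            return []

    DEFAULT_MODEL = "RealVisXL_V5.0_Lightning"
    MODEL_CHOICES = [DEFAULT_MODEL, "RealVisXL_V4.0"]
    models = {name: DummyPipe(name) for name in MODEL_CHOICES}

    def generate(prompt: str, model_choice: str = DEFAULT_MODEL):
        # dict.get with a fallback: an unrecognized dropdown value degrades to
        # the default pipeline instead of raising KeyError mid-request.
        selected_pipe = models.get(model_choice, models[DEFAULT_MODEL])
        return selected_pipe(prompt=prompt)

    generate("astronaut riding a horse", "RealVisXL_V4.0")  # generating with RealVisXL_V4.0

Loading both pipelines eagerly roughly doubles startup memory but makes switching instant; loading lazily on first selection would be the usual trade-off in the other direction.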
 
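
Because model_selection is appended to the inputs of the handler registered with api_name="run", the new argument also surfaces on the Space's programmatic API. A hedged sketch using gradio_client: the Space id below is a placeholder, and the exact ordered parameter list should be read from view_api() rather than assumed.

    from gradio_client import Client

    # Placeholder Space id; substitute the real owner/space name.
    client = Client("prithivMLmods/SOME-SPACE")

    # Prints every named endpoint, including /run, with its ordered parameters;
    # model_selection should now be listed after randomize_seed.
    client.view_api()

    # Once the signature is confirmed, call the endpoint, e.g.:
    #   client.predict(..., api_name="/run")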