prithivMLmods committed · Commit 786abd0 · verified · 1 Parent(s): eb72d75

Update app.py

Files changed (1):
  1. app.py +259 -454
app.py CHANGED
@@ -2,15 +2,71 @@ import os
  import random
  import uuid
  import json
  import gradio as gr
- import numpy as np
- from PIL import Image
  import spaces
  import torch
- from diffusers import DiffusionPipeline, StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
- from typing import Tuple

- # Load restricted words
  bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
  bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
  default_negative = os.getenv("default_negative", "")
@@ -24,218 +80,11 @@ def check_text(prompt, negative=""):
              return True
      return False

- # Quality/Style--------------------------------------------------------------------
- style_list = [
-     {
-         "name": "3840 x 2160",
-         "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
-     },
-     {
-         "name": "2560 x 1440",
-         "prompt": "hyper-realistic 4K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
-     },
-     {
-         "name": "HD+",
-         "prompt": "hyper-realistic 2K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
-     },
-     {
-         "name": "Style Zero",
-         "prompt": "{prompt}",
-         "negative_prompt": "",
-     },
- ]
-
- # Collage styles--------------------------------------------------------------------
- collage_style_list = [
-     {
-         "name": "Hi-Res",
-         "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
-         "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
-     },
-     {
-         "name": "B & W",
-         "prompt": "black and white collage of {prompt}. monochromatic, timeless, classic, dramatic contrast",
-         "negative_prompt": "colorful, vibrant, bright, flashy",
-     },
-     {
-         "name": "Polaroid",
-         "prompt": "collage of polaroid photos featuring {prompt}. vintage style, high contrast, nostalgic, instant film aesthetic",
-         "negative_prompt": "digital, modern, low quality, blurry",
-     },
-     {
-         "name": "Watercolor",
-         "prompt": "watercolor collage of {prompt}. soft edges, translucent colors, painterly effects",
-         "negative_prompt": "digital, sharp lines, solid colors",
-     },
-     {
-         "name": "Cinematic",
-         "prompt": "cinematic collage of {prompt}. film stills, movie posters, dramatic lighting",
-         "negative_prompt": "static, lifeless, mundane",
-     },
-     {
-         "name": "Nostalgic",
-         "prompt": "nostalgic collage of {prompt}. retro imagery, vintage objects, sentimental journey",
-         "negative_prompt": "contemporary, futuristic, forward-looking",
-     },
-     {
-         "name": "Vintage",
-         "prompt": "vintage collage of {prompt}. aged paper, sepia tones, retro imagery, antique vibes",
-         "negative_prompt": "modern, contemporary, futuristic, high-tech",
-     },
-     {
-         "name": "Scrapbook",
-         "prompt": "scrapbook style collage of {prompt}. mixed media, hand-cut elements, textures, paper, stickers, doodles",
-         "negative_prompt": "clean, digital, modern, low quality",
-     },
-     {
-         "name": "NeoNGlow",
-         "prompt": "neon glow collage of {prompt}. vibrant colors, glowing effects, futuristic vibes",
-         "negative_prompt": "dull, muted colors, vintage, retro",
-     },
-     {
-         "name": "Geometric",
-         "prompt": "geometric collage of {prompt}. abstract shapes, colorful, sharp edges, modern design, high quality",
-         "negative_prompt": "blurry, low quality, traditional, dull",
-     },
-     {
-         "name": "Thematic",
-         "prompt": "thematic collage of {prompt}. cohesive theme, well-organized, matching colors, creative layout",
-         "negative_prompt": "random, messy, unorganized, clashing colors",
-     },
-     {
-         "name": "Cherry",
-         "prompt": "Duotone style Cherry tone applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Fuchsia",
-         "prompt": "Duotone style Fuchsia tone applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Pop",
-         "prompt": "Duotone style Pop tone applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Violet",
-         "prompt": "Duotone style Violet applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Sea Blue",
-         "prompt": "Duotone style Sea Blue applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Sea Green",
-         "prompt": "Duotone style Sea Green applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Mustard",
-         "prompt": "Duotone style Mustard applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Amber",
-         "prompt": "Duotone style Amber applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Pomelo",
-         "prompt": "Duotone style Pomelo applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Peppermint",
-         "prompt": "Duotone style Peppermint applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Mystic",
-         "prompt": "Duotone style Mystic tone applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Pastel",
-         "prompt": "Duotone style Pastel applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "Coral",
-         "prompt": "Duotone style Coral applied to {prompt}",
-         "negative_prompt": "",
-     },
-     {
-         "name": "No Style",
-         "prompt": "{prompt}",
-         "negative_prompt": "",
-     },
- ]
-
- # Filters--------------------------------------------------------------------
- filters = {
-     "Vivid": {
-         "prompt": "extra vivid {prompt}",
-         "negative_prompt": "washed out, dull"
-     },
-     "Playa": {
-         "prompt": "{prompt} set in a vast playa",
-         "negative_prompt": "forest, mountains"
-     },
-     "Desert": {
-         "prompt": "{prompt} set in a desert landscape",
-         "negative_prompt": "ocean, city"
-     },
-     "West": {
-         "prompt": "{prompt} with a western theme",
-         "negative_prompt": "eastern, modern"
-     },
-     "Blush": {
-         "prompt": "{prompt} with a soft blush color palette",
-         "negative_prompt": "harsh colors, neon"
-     },
-     "Minimalist": {
-         "prompt": "{prompt} with a minimalist design",
-         "negative_prompt": "cluttered, ornate"
-     },
-     "Zero filter": {
-         "prompt": "{prompt}",
-         "negative_prompt": ""
-     },
- }
-
- styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
- collage_styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in collage_style_list}
- filter_styles = {k: (v["prompt"], v["negative_prompt"]) for k, v in filters.items()}
-
- STYLE_NAMES = list(styles.keys())
- COLLAGE_STYLE_NAMES = list(collage_styles.keys())
- FILTER_NAMES = list(filters.keys())
- DEFAULT_STYLE_NAME = "3840 x 2160"
- DEFAULT_COLLAGE_STYLE_NAME = "Hi-Res"
- DEFAULT_FILTER_NAME = "Zero filter"
-
- def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
-     if style_name in styles:
-         p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
-     elif style_name in collage_styles:
-         p, n = collage_styles.get(style_name, collage_styles[DEFAULT_COLLAGE_STYLE_NAME])
-     elif style_name in filter_styles:
-         p, n = filter_styles.get(style_name, filter_styles[DEFAULT_FILTER_NAME])
-     else:
-         p, n = styles[DEFAULT_STYLE_NAME]
-
-     if not negative:
-         negative = ""
-     return p.replace("{prompt}", positive), n + negative
-
- if not torch.cuda.is_available():
-     DESCRIPTION = "\n<p>⚠️Running on CPU, This may not work on CPU.</p>"

  MAX_SEED = np.iinfo(np.int32).max
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
@@ -243,53 +92,46 @@ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- # Set dtype based on device: half for CUDA, float32 for CPU
  dtype = torch.float16 if device.type == "cuda" else torch.float32

- # Load primary model (RealVisXL_V5.0_Lightning)
  if torch.cuda.is_available():
      pipe = StableDiffusionXLPipeline.from_pretrained(
-         #"SG161222/RealVisXL_V5.0",
          "SG161222/RealVisXL_V5.0_Lightning",
          torch_dtype=dtype,
          use_safetensors=True,
          add_watermarker=False
      ).to(device)
-     # Ensure text encoder uses half precision on GPU
      pipe.text_encoder = pipe.text_encoder.half()
-
      if ENABLE_CPU_OFFLOAD:
          pipe.enable_model_cpu_offload()
      else:
          pipe.to(device)
          print("Loaded RealVisXL_V5.0_Lightning on Device!")
-
      if USE_TORCH_COMPILE:
          pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
          print("Model RealVisXL_V5.0_Lightning Compiled!")

-     # Load second model (RealVisXL_V4.0)
      pipe2 = StableDiffusionXLPipeline.from_pretrained(
-         #"SG161222/RealVisXL_V4.0",
          "SG161222/RealVisXL_V4.0_Lightning",
          torch_dtype=dtype,
          use_safetensors=True,
          add_watermarker=False,
      ).to(device)
      pipe2.text_encoder = pipe2.text_encoder.half()
-
      if ENABLE_CPU_OFFLOAD:
          pipe2.enable_model_cpu_offload()
      else:
          pipe2.to(device)
          print("Loaded RealVisXL_V4.0 on Device!")
-
      if USE_TORCH_COMPILE:
          pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
          print("Model RealVisXL_V4.0 Compiled!")

-     # Load third model
      pipe3 = StableDiffusionXLPipeline.from_pretrained(
          "SG161222/RealVisXL_V3.0_Turbo",
          torch_dtype=dtype,
@@ -297,18 +139,15 @@ if torch.cuda.is_available():
          add_watermarker=False,
      ).to(device)
      pipe3.text_encoder = pipe3.text_encoder.half()
-
      if ENABLE_CPU_OFFLOAD:
          pipe3.enable_model_cpu_offload()
      else:
          pipe3.to(device)
          print("Loaded Animagine XL 4.0 on Device!")
-
      if USE_TORCH_COMPILE:
          pipe3.unet = torch.compile(pipe3.unet, mode="reduce-overhead", fullgraph=True)
          print("Model Animagine XL 4.0 Compiled!")
  else:
-     # On CPU, load all models in float32
      pipe = StableDiffusionXLPipeline.from_pretrained(
          "SG161222/RealVisXL_V5.0_Lightning",
          torch_dtype=dtype,
@@ -329,7 +168,7 @@ else:
      ).to(device)
      print("Running on CPU; models loaded in float32.")

- # A dictionary to easily choose the model based on selection.
  DEFAULT_MODEL = "Lightning 5"
  MODEL_CHOICES = [DEFAULT_MODEL, "Lightning 4", "Turbo v3"]
  models = {
@@ -338,266 +177,232 @@ models = {
      "Turbo v3": pipe3
  }

- def save_image(img, path):
-     img.save(path)
-
- def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     return seed
-
- @spaces.GPU
- def generate(
-     prompt: str,
-     negative_prompt: str = "",
-     use_negative_prompt: bool = False,
-     style: str = DEFAULT_STYLE_NAME,
-     collage_style: str = DEFAULT_COLLAGE_STYLE_NAME,
-     filter_name: str = DEFAULT_FILTER_NAME,
-     grid_size: str = "2x2",
-     seed: int = 0,
-     width: int = 1024,
-     height: int = 1024,
-     guidance_scale: float = 3,
-     randomize_seed: bool = False,
-     model_choice: str = DEFAULT_MODEL,
-     use_resolution_binning: bool = True,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if check_text(prompt, negative_prompt):
          raise ValueError("Prompt contains restricted words.")

-     if collage_style != "No Style":
-         prompt, negative_prompt = apply_style(collage_style, prompt, negative_prompt)
-     elif filter_name != "No Filter":
-         prompt, negative_prompt = apply_style(filter_name, prompt, negative_prompt)
-     else:
-         prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
-
      seed = int(randomize_seed_fn(seed, randomize_seed))
      generator = torch.Generator(device=device).manual_seed(seed)

-     if not use_negative_prompt:
-         negative_prompt = ""
-     negative_prompt += default_negative
-
      grid_sizes = {
          "2x1": (2, 1),
          "1x2": (1, 2),
          "2x2": (2, 2),
-         "2x3": (2, 3),
-         "3x2": (3, 2),
          "1x1": (1, 1)
      }
-
-     grid_size_x, grid_size_y = grid_sizes.get(grid_size, (2, 2))
-     num_images = grid_size_x * grid_size_y

      options = {
          "prompt": prompt,
-         "negative_prompt": negative_prompt,
          "width": width,
          "height": height,
          "guidance_scale": guidance_scale,
          "num_inference_steps": 30,
          "generator": generator,
          "num_images_per_prompt": num_images,
-         "use_resolution_binning": use_resolution_binning,
          "output_type": "pil",
      }

      if device.type == "cuda":
          torch.cuda.empty_cache()

-     # Choose pipeline based on user selection
      selected_pipe = models.get(model_choice, pipe)
      images = selected_pipe(**options).images

-     grid_img = Image.new('RGB', (width * grid_size_x, height * grid_size_y))
      for i, img in enumerate(images[:num_images]):
-         grid_img.paste(img, (i % grid_size_x * width, i // grid_size_x * height))

      unique_name = str(uuid.uuid4()) + ".png"
-     save_image(grid_img, unique_name)
      return [unique_name], seed

- examples = [
-     "Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic oil --ar 2:3 --q 2 --s 750 --v 5",
-     "3d image, cute girl, in the style of Pixar --ar 1:2 --stylize 750, 4K resolution highlights, Sharp focus, octane render, ray tracing, Ultra-High-Definition, 8k, UHD, HDR, (Masterpiece:1.5), (best quality:1.5)",
-     "Cold coffee in a cup bokeh --ar 85:128 --v 6.0 --style raw5, 4k hdr, retro",
-     "Super Realism, High-resolution photograph, woman, UHD, photorealistic, shot on a Sony A7III --chaos 20 --ar 1:2 --style raw --stylize 250 --realism --soft"
- ]

  css = '''
- .gradio-container {
-     max-width: 888px !important;
-     margin: 0 auto !important;
-     display: flex;
-     flex-direction: column;
-     align-items: center;
- }
  h1 {
-     text-align: center;
  }
- '''

- title = """<h1 align="center">IMAGINEO 4K : SDXL🔥</h1>
- <p><center>
- <a href="https://huggingface.co/SG161222/RealVisXL_V4.0_Lightning" target="_blank">[Lightning 4]</a>
- <a href="https://huggingface.co/SG161222/RealVisXL_V5.0_Lightning" target="_blank">[Lightning 5]</a>
- <a href="https://huggingface.co/SG161222/RealVisXL_V3.0_Turbo" target="_blank">[Turbo v3]</a>
- </center></p>
- """

- with gr.Blocks(theme="YTheme/Minecraft", css=css) as demo:
-     gr.HTML(title)
-     with gr.Row():
-         with gr.Column(scale=1):
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
-             run_button = gr.Button("Generate Image ( 1024 x 1024 ) 🧤", scale=0)
-
-             with gr.Row(visible=True):
-                 model_selection = gr.Dropdown(
-                     choices=MODEL_CHOICES,
-                     value=DEFAULT_MODEL,
-                     label="Model Selection",
-                 )
-             with gr.Row(visible=True):
-                 grid_size_selection = gr.Dropdown(
-                     choices=["2x1", "1x2", "2x2", "2x3", "3x2", "1x1"],
-                     value="1x1",
-                     label="Grid Size"
-                 )
-             with gr.Row(visible=True):
-                 filter_selection = gr.Dropdown(
-                     show_label=True,
-                     container=True,
-                     interactive=True,
-                     choices=FILTER_NAMES,
-                     value=DEFAULT_FILTER_NAME,
-                     label="Filter Type",
-                 )
-             with gr.Row(visible=True):
-                 collage_style_selection = gr.Dropdown(
-                     show_label=True,
-                     container=True,
-                     interactive=True,
-                     choices=COLLAGE_STYLE_NAMES,
-                     value=DEFAULT_COLLAGE_STYLE_NAME,
-                     label="Collage Template + Duotone Canvas",
-                 )
-             with gr.Row(visible=True):
-                 style_selection = gr.Dropdown(
-                     show_label=True,
-                     container=True,
-                     interactive=True,
-                     choices=STYLE_NAMES,
-                     value=DEFAULT_STYLE_NAME,
-                     label="Quality Style",
-                 )
-             with gr.Accordion("Advanced options", open=False):
-                 use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True, visible=True)
-                 negative_prompt = gr.Text(
-                     label="Negative prompt",
-                     max_lines=1,
-                     placeholder="Enter a negative prompt",
-                     value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
-                     visible=True,
-                 )
-                 with gr.Row():
-                     num_inference_steps = gr.Slider(
-                         label="Steps",
-                         minimum=10,
-                         maximum=60,
-                         step=1,
-                         value=30,
-                     )
-                 with gr.Row():
-                     num_images_per_prompt = gr.Slider(
-                         label="Images",
-                         minimum=1,
-                         maximum=5,
-                         step=1,
-                         value=2,
-                     )
-                 seed = gr.Slider(
-                     label="Seed",
-                     minimum=0,
-                     maximum=MAX_SEED,
-                     step=1,
-                     value=0,
-                     visible=True
-                 )
-                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                 with gr.Row(visible=True):
-                     width = gr.Slider(
-                         label="Width",
-                         minimum=512,
-                         maximum=2048,
-                         step=64,
-                         value=1024,
-                     )
-                     height = gr.Slider(
-                         label="Height",
-                         minimum=512,
-                         maximum=2048,
-                         step=64,
-                         value=1024,
-                     )
-                 with gr.Row():
-                     guidance_scale = gr.Slider(
-                         label="Guidance Scale",
-                         minimum=0.1,
-                         maximum=20.0,
-                         step=0.1,
-                         value=6,
-                     )
-         with gr.Column(scale=2):
-             result = gr.Gallery(label="Result", columns=1, show_label=False)
-             gr.Examples(
-                 examples=examples,
-                 inputs=prompt,
-                 outputs=[result, seed],
-                 fn=generate,
-                 cache_examples=CACHE_EXAMPLES,
-             )
-     use_negative_prompt.change(
-         fn=lambda x: gr.update(visible=x),
-         inputs=use_negative_prompt,
-         outputs=negative_prompt,
-         api_name=False,
-     )
-     gr.on(
-         triggers=[
-             prompt.submit,
-             negative_prompt.submit,
-             run_button.click,
-         ],
-         fn=generate,
-         inputs=[
-             prompt,
-             negative_prompt,
-             use_negative_prompt,
-             style_selection,
-             collage_style_selection,
-             filter_selection,
-             grid_size_selection,
-             seed,
-             width,
-             height,
-             guidance_scale,
-             randomize_seed,
-             model_selection,
-         ],
-         outputs=[result, seed],
-         api_name="run",
-     )

  if __name__ == "__main__":
-     demo.queue(max_size=40).launch()
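For reference, the removed style system above works by plain template substitution: each style's prompt template carries a {prompt} placeholder that apply_style fills in, and the style's negative prompt is concatenated onto the user's. A minimal standalone sketch of that mechanism (illustrative only; the dictionary is trimmed to one entry):

    # Sketch of the removed template-substitution idea, not the full app.
    styles = {"HD+": ("hyper-realistic 2K image of {prompt}. ultra-detailed", "blurry")}

    def apply_style(name, positive, negative=""):
        # Fall back to the prompt unchanged if the style name is unknown.
        template, neg = styles.get(name, ("{prompt}", ""))
        return template.replace("{prompt}", positive), neg + negative

    print(apply_style("HD+", "a red fox"))
    # ('hyper-realistic 2K image of a red fox. ultra-detailed', 'blurry')

The new version of app.py replaces all of this with flag parsing inside a single chat handler, shown below.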
 
  import random
  import uuid
  import json
+ import time
+ import asyncio
+ import re
+ from threading import Thread
+
  import gradio as gr
  import spaces
  import torch
+ import numpy as np
+ from PIL import Image
+ import edge_tts
+
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
+     Qwen2VLForConditionalGeneration,
+     AutoProcessor,
+ )
+ from transformers.image_utils import load_image
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # Load text-only model and tokenizer for chat generation
+ model_id = "prithivMLmods/FastThink-0.5B-Tiny"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.eval()

+ # TTS Voices and processor for multimodal chat
+ TTS_VOICES = [
+     "en-US-JennyNeural",  # @tts1
+     "en-US-GuyNeural",    # @tts2
+ ]
+ MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+ processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
+ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+     MODEL_ID_VL,
+     trust_remote_code=True,
+     torch_dtype=torch.float16
+ ).to("cuda").eval()
+
+ # A helper function to convert text to speech via Edge TTS
+ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
+     communicate = edge_tts.Communicate(text, voice)
+     await communicate.save(output_file)
+     return output_file
+
+ def clean_chat_history(chat_history):
+     cleaned = []
+     for msg in chat_history:
+         if isinstance(msg, dict) and isinstance(msg.get("content"), str):
+             cleaned.append(msg)
+     return cleaned
+
+ # Restricted words check (if any)
  bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
  bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
  default_negative = os.getenv("default_negative", "")

              return True
      return False

+ # Use the same random seed function for both text and image generation
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed

  MAX_SEED = np.iinfo(np.int32).max
  CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"

+ # Set dtype based on device: use half for CUDA, float32 otherwise.
  dtype = torch.float16 if device.type == "cuda" else torch.float32

+ # Load image generation pipelines for the three model choices.
  if torch.cuda.is_available():
+     # Lightning 5 model
      pipe = StableDiffusionXLPipeline.from_pretrained(
          "SG161222/RealVisXL_V5.0_Lightning",
          torch_dtype=dtype,
          use_safetensors=True,
          add_watermarker=False
      ).to(device)
      pipe.text_encoder = pipe.text_encoder.half()
      if ENABLE_CPU_OFFLOAD:
          pipe.enable_model_cpu_offload()
      else:
          pipe.to(device)
          print("Loaded RealVisXL_V5.0_Lightning on Device!")
      if USE_TORCH_COMPILE:
          pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
          print("Model RealVisXL_V5.0_Lightning Compiled!")

+     # Lightning 4 model
      pipe2 = StableDiffusionXLPipeline.from_pretrained(
          "SG161222/RealVisXL_V4.0_Lightning",
          torch_dtype=dtype,
          use_safetensors=True,
          add_watermarker=False,
      ).to(device)
      pipe2.text_encoder = pipe2.text_encoder.half()
      if ENABLE_CPU_OFFLOAD:
          pipe2.enable_model_cpu_offload()
      else:
          pipe2.to(device)
          print("Loaded RealVisXL_V4.0 on Device!")
      if USE_TORCH_COMPILE:
          pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
          print("Model RealVisXL_V4.0 Compiled!")

+     # Turbo v3 model
      pipe3 = StableDiffusionXLPipeline.from_pretrained(
          "SG161222/RealVisXL_V3.0_Turbo",
          torch_dtype=dtype,
          add_watermarker=False,
      ).to(device)
      pipe3.text_encoder = pipe3.text_encoder.half()
      if ENABLE_CPU_OFFLOAD:
          pipe3.enable_model_cpu_offload()
      else:
          pipe3.to(device)
          print("Loaded Animagine XL 4.0 on Device!")
      if USE_TORCH_COMPILE:
          pipe3.unet = torch.compile(pipe3.unet, mode="reduce-overhead", fullgraph=True)
          print("Model Animagine XL 4.0 Compiled!")
  else:
      pipe = StableDiffusionXLPipeline.from_pretrained(
          "SG161222/RealVisXL_V5.0_Lightning",
          torch_dtype=dtype,

      ).to(device)
      print("Running on CPU; models loaded in float32.")

+ # Define available model choices and their mapping.
  DEFAULT_MODEL = "Lightning 5"
  MODEL_CHOICES = [DEFAULT_MODEL, "Lightning 4", "Turbo v3"]
  models = {

      "Turbo v3": pipe3
  }

+ def generate_image_grid(prompt: str, seed: int, grid_size: str, width: int, height: int,
+                         guidance_scale: float, randomize_seed: bool, model_choice: str):
+     if check_text(prompt, ""):
          raise ValueError("Prompt contains restricted words.")

      seed = int(randomize_seed_fn(seed, randomize_seed))
      generator = torch.Generator(device=device).manual_seed(seed)

+     # Define supported grid sizes.
      grid_sizes = {
          "2x1": (2, 1),
          "1x2": (1, 2),
          "2x2": (2, 2),
          "1x1": (1, 1)
      }
+     grid_size_tuple = grid_sizes.get(grid_size, (1, 1))
+     num_images = grid_size_tuple[0] * grid_size_tuple[1]

      options = {
          "prompt": prompt,
+         "negative_prompt": default_negative,
          "width": width,
          "height": height,
          "guidance_scale": guidance_scale,
          "num_inference_steps": 30,
          "generator": generator,
          "num_images_per_prompt": num_images,
+         "use_resolution_binning": True,
          "output_type": "pil",
      }

      if device.type == "cuda":
          torch.cuda.empty_cache()

      selected_pipe = models.get(model_choice, pipe)
      images = selected_pipe(**options).images

+     # Create a grid image.
+     grid_img = Image.new('RGB', (width * grid_size_tuple[0], height * grid_size_tuple[1]))
      for i, img in enumerate(images[:num_images]):
+         grid_img.paste(img, ((i % grid_size_tuple[0]) * width, (i // grid_size_tuple[0]) * height))

      unique_name = str(uuid.uuid4()) + ".png"
+     grid_img.save(unique_name)
      return [unique_name], seed

+ # -----------------------------
227
+ # Main generate() Function
228
+ # -----------------------------
229
+ @spaces.GPU
230
+ def generate(
231
+ input_dict: dict,
232
+ chat_history: list[dict],
233
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
234
+ temperature: float = 0.6,
235
+ top_p: float = 0.9,
236
+ top_k: int = 50,
237
+ repetition_penalty: float = 1.2,
238
+ ):
239
+ text = input_dict["text"]
240
+ files = input_dict.get("files", [])
241
+
242
+ lower_text = text.lower().strip()
243
+ # Check if the prompt is an image generation command using model flags.
244
+ if (lower_text.startswith("@lightningv5") or
245
+ lower_text.startswith("@lightningv4") or
246
+ lower_text.startswith("@turbov3")):
247
+
248
+ # Determine model choice based on flag.
249
+ model_choice = None
250
+ if "@lightningv5" in lower_text:
251
+ model_choice = "Lightning 5"
252
+ elif "@lightningv4" in lower_text:
253
+ model_choice = "Lightning 4"
254
+ elif "@turbov3" in lower_text:
255
+ model_choice = "Turbo v3"
256
+
257
+ # Parse grid size flag e.g. "@2x2"
258
+ grid_match = re.search(r"@(\d+x\d+)", lower_text)
259
+ grid_size = grid_match.group(1) if grid_match else "1x1"
260
+
261
+ # Remove the model and grid flags from the prompt.
262
+ prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE)
263
+ prompt_clean = re.sub(r"@lightningv4", "", prompt_clean, flags=re.IGNORECASE)
264
+ prompt_clean = re.sub(r"@turbov3", "", prompt_clean, flags=re.IGNORECASE)
265
+ prompt_clean = re.sub(r"@\d+x\d+", "", prompt_clean, flags=re.IGNORECASE)
266
+ prompt_clean = prompt_clean.strip().strip('"')
267
+
268
+ # Default parameters for image generation.
269
+ width = 1024
270
+ height = 1024
271
+ guidance_scale = 6.0
272
+ seed_val = 0
273
+ randomize_seed = True
274
+ use_resolution_binning = True
275
+
276
+ yield "Generating image grid..."
277
+ image_paths, used_seed = generate_image_grid(
278
+ prompt_clean,
279
+ seed_val,
280
+ grid_size,
281
+ width,
282
+ height,
283
+ guidance_scale,
284
+ randomize_seed,
285
+ model_choice,
286
+ )
287
+ yield gr.Image(image_paths[0])
288
+ return
289
+
290
+ # Otherwise, handle text/chat (and TTS) generation.
291
+ tts_prefix = "@tts"
292
+ is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
293
+ voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
294
+
295
+ if is_tts and voice_index:
296
+ voice = TTS_VOICES[voice_index - 1]
297
+ text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
298
+ conversation = [{"role": "user", "content": text}]
299
+ else:
300
+ voice = None
301
+ text = text.replace(tts_prefix, "").strip()
302
+ conversation = clean_chat_history(chat_history)
303
+ conversation.append({"role": "user", "content": text})
304
+
305
+ if files:
306
+ images = [load_image(image) for image in files] if len(files) > 1 else [load_image(files[0])]
307
+ messages = [{
308
+ "role": "user",
309
+ "content": [
310
+ *[{"type": "image", "image": image} for image in images],
311
+ {"type": "text", "text": text},
312
+ ]
313
+ }]
314
+ prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
315
+ inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
316
+ streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
317
+ generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
318
+ thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
319
+ thread.start()
320
+
321
+ buffer = ""
322
+ yield "Thinking..."
323
+ for new_text in streamer:
324
+ buffer += new_text
325
+ buffer = buffer.replace("<|im_end|>", "")
326
+ time.sleep(0.01)
327
+ yield buffer
328
+ else:
329
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
330
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
331
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
332
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
333
+ input_ids = input_ids.to(model.device)
334
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
335
+ generation_kwargs = {
336
+ "input_ids": input_ids,
337
+ "streamer": streamer,
338
+ "max_new_tokens": max_new_tokens,
339
+ "do_sample": True,
340
+ "top_p": top_p,
341
+ "top_k": top_k,
342
+ "temperature": temperature,
343
+ "num_beams": 1,
344
+ "repetition_penalty": repetition_penalty,
345
+ }
346
+ t = Thread(target=model.generate, kwargs=generation_kwargs)
347
+ t.start()
348
+
349
+ outputs = []
350
+ for new_text in streamer:
351
+ outputs.append(new_text)
352
+ yield "".join(outputs)
353
+
354
+ final_response = "".join(outputs)
355
+ yield final_response
356
+
357
+ if is_tts and voice:
358
+ output_file = asyncio.run(text_to_speech(final_response, voice))
359
+ yield gr.Audio(output_file, autoplay=True)
360
+
361
+
362
+ DESCRIPTION = """
363
+ # IMAGINEO 4K ⚡
364
+ """
365
 
366
  css = '''
 
 
 
 
 
 
 
367
  h1 {
368
+ text-align: center;
369
+ display: block;
370
  }
 
371
 
372
+ #duplicate-button {
373
+ margin: auto;
374
+ color: #fff;
375
+ background: #1565c0;
376
+ border-radius: 100vh;
377
+ }
378
+ '''
379
 
380
+ demo = gr.ChatInterface(
381
+ fn=generate,
382
+ additional_inputs=[
383
+ gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
384
+ gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
385
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
386
+ gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
387
+ gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
388
+ ],
389
+ examples=[
390
+ ["@tts1 Who is Nikola Tesla, and why did he die?"],
391
+ ['@lightningv5 @2x2 "Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"'],
392
+ ['@lightningv4 @1x1 "A serene landscape with mountains"'],
393
+ ['@turbov3 @2x1 "Abstract art, colorful and vibrant"'],
394
+ ["Write a Python function to check if a number is prime."],
395
+ ["@tts2 What causes rainbows to form?"],
396
+ ],
397
+ cache_examples=False,
398
+ type="messages",
399
+ description=DESCRIPTION,
400
+ css=css,
401
+ fill_height=True,
402
+ textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
403
+ stop_btn="Stop Generation",
404
+ multimodal=True,
405
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
 
407
  if __name__ == "__main__":
408
+ demo.queue(max_size=20).launch(share=True)
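The rewritten app routes everything through a single chat box: a message starting with @lightningv5, @lightningv4, or @turbov3 becomes an image-grid request (optionally with an @NxM grid flag), @tts1/@tts2 request a spoken reply, and anything else goes to the chat model. A minimal, self-contained sketch of that flag routing, mirroring the regexes in the new generate() (the route() function and its return tuples are illustrative, not part of the app):

    import re

    def route(text: str):
        # Mirror of the commit's flag parsing, for illustration only.
        lower = text.lower().strip()
        models = {"@lightningv5": "Lightning 5", "@lightningv4": "Lightning 4", "@turbov3": "Turbo v3"}
        for flag, model in models.items():
            if lower.startswith(flag):
                grid = re.search(r"@(\d+x\d+)", lower)
                prompt = re.sub(r"@lightningv5|@lightningv4|@turbov3|@\d+x\d+", "", text, flags=re.IGNORECASE)
                return "image", model, (grid.group(1) if grid else "1x1"), prompt.strip().strip('"')
        if lower.startswith("@tts1") or lower.startswith("@tts2"):
            return "tts", None, None, text[5:].strip()
        return "chat", None, None, text

    print(route('@lightningv5 @2x2 "a serene landscape"'))
    # ('image', 'Lightning 5', '2x2', 'a serene landscape')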