prithivMLmods commited on
Commit
905e633
·
verified ·
1 Parent(s): 804f76a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -80
app.py CHANGED
@@ -30,7 +30,6 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
 
33
- # Load text-only model and tokenizer for chat generation
34
  model_id = "prithivMLmods/FastThink-0.5B-Tiny"
35
  tokenizer = AutoTokenizer.from_pretrained(model_id)
36
  model = AutoModelForCausalLM.from_pretrained(
@@ -40,7 +39,6 @@ model = AutoModelForCausalLM.from_pretrained(
40
  )
41
  model.eval()
42
 
43
- # TTS Voices and processor for multimodal chat
44
  TTS_VOICES = [
45
  "en-US-JennyNeural", # @tts1
46
  "en-US-GuyNeural", # @tts2
@@ -53,7 +51,6 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
53
  torch_dtype=torch.float16
54
  ).to("cuda").eval()
55
 
56
- # A helper function to convert text to speech via Edge TTS
57
  async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
58
  communicate = edge_tts.Communicate(text, voice)
59
  await communicate.save(output_file)
@@ -66,7 +63,6 @@ def clean_chat_history(chat_history):
66
  cleaned.append(msg)
67
  return cleaned
68
 
69
- # Restricted words check (if any)
70
  bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
71
  bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
72
  default_negative = os.getenv("default_negative", "")
@@ -80,7 +76,6 @@ def check_text(prompt, negative=""):
80
  return True
81
  return False
82
 
83
- # Use the same random seed function for both text and image generation
84
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
85
  if randomize_seed:
86
  seed = random.randint(0, MAX_SEED)
@@ -92,10 +87,8 @@ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
92
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
93
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
94
 
95
- # Set dtype based on device: use half for CUDA, float32 otherwise.
96
  dtype = torch.float16 if device.type == "cuda" else torch.float32
97
 
98
- # Load image generation pipelines for the three model choices.
99
  if torch.cuda.is_available():
100
  # Lightning 5 model
101
  pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -168,7 +161,6 @@ else:
168
  ).to(device)
169
  print("Running on CPU; models loaded in float32.")
170
 
171
- # Define available model choices and their mapping.
172
  DEFAULT_MODEL = "Lightning 5"
173
  MODEL_CHOICES = [DEFAULT_MODEL, "Lightning 4", "Turbo v3"]
174
  models = {
@@ -177,55 +169,11 @@ models = {
177
  "Turbo v3": pipe3
178
  }
179
 
180
- def generate_image_grid(prompt: str, seed: int, grid_size: str, width: int, height: int,
181
- guidance_scale: float, randomize_seed: bool, model_choice: str):
182
- if check_text(prompt, ""):
183
- raise ValueError("Prompt contains restricted words.")
184
-
185
- seed = int(randomize_seed_fn(seed, randomize_seed))
186
- generator = torch.Generator(device=device).manual_seed(seed)
187
-
188
- # Define supported grid sizes.
189
- grid_sizes = {
190
- "2x1": (2, 1),
191
- "1x2": (1, 2),
192
- "2x2": (2, 2),
193
- "1x1": (1, 1)
194
- }
195
- grid_size_tuple = grid_sizes.get(grid_size, (1, 1))
196
- num_images = grid_size_tuple[0] * grid_size_tuple[1]
197
-
198
- options = {
199
- "prompt": prompt,
200
- "negative_prompt": default_negative,
201
- "width": width,
202
- "height": height,
203
- "guidance_scale": guidance_scale,
204
- "num_inference_steps": 30,
205
- "generator": generator,
206
- "num_images_per_prompt": num_images,
207
- "use_resolution_binning": True,
208
- "output_type": "pil",
209
- }
210
-
211
- if device.type == "cuda":
212
- torch.cuda.empty_cache()
213
-
214
- selected_pipe = models.get(model_choice, pipe)
215
- images = selected_pipe(**options).images
216
-
217
- # Create a grid image.
218
- grid_img = Image.new('RGB', (width * grid_size_tuple[0], height * grid_size_tuple[1]))
219
- for i, img in enumerate(images[:num_images]):
220
- grid_img.paste(img, ((i % grid_size_tuple[0]) * width, (i // grid_size_tuple[0]) * height))
221
-
222
  unique_name = str(uuid.uuid4()) + ".png"
223
- grid_img.save(unique_name)
224
- return [unique_name], seed
225
 
226
- # -----------------------------
227
- # Main generate() Function
228
- # -----------------------------
229
  @spaces.GPU
230
  def generate(
231
  input_dict: dict,
@@ -254,37 +202,42 @@ def generate(
254
  elif "@turbov3" in lower_text:
255
  model_choice = "Turbo v3"
256
 
257
- # Parse grid size flag e.g. "@2x2"
258
- grid_match = re.search(r"@(\d+x\d+)", lower_text)
259
- grid_size = grid_match.group(1) if grid_match else "1x1"
260
-
261
- # Remove the model and grid flags from the prompt.
262
  prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE)
263
  prompt_clean = re.sub(r"@lightningv4", "", prompt_clean, flags=re.IGNORECASE)
264
  prompt_clean = re.sub(r"@turbov3", "", prompt_clean, flags=re.IGNORECASE)
265
- prompt_clean = re.sub(r"@\d+x\d+", "", prompt_clean, flags=re.IGNORECASE)
266
  prompt_clean = prompt_clean.strip().strip('"')
267
 
268
- # Default parameters for image generation.
269
  width = 1024
270
  height = 1024
271
  guidance_scale = 6.0
272
  seed_val = 0
273
- randomize_seed = True
274
- use_resolution_binning = True
 
 
275
 
276
- yield "Generating image grid..."
277
- image_paths, used_seed = generate_image_grid(
278
- prompt_clean,
279
- seed_val,
280
- grid_size,
281
- width,
282
- height,
283
- guidance_scale,
284
- randomize_seed,
285
- model_choice,
286
- )
287
- yield gr.Image(image_paths[0])
 
 
 
 
 
 
 
 
288
  return
289
 
290
  # Otherwise, handle text/chat (and TTS) generation.
@@ -358,7 +311,6 @@ def generate(
358
  output_file = asyncio.run(text_to_speech(final_response, voice))
359
  yield gr.Audio(output_file, autoplay=True)
360
 
361
-
362
  DESCRIPTION = """
363
  # IMAGINEO 4K ⚡
364
  """
@@ -388,9 +340,9 @@ demo = gr.ChatInterface(
388
  ],
389
  examples=[
390
  ["@tts1 Who is Nikola Tesla, and why did he die?"],
391
- ['@lightningv5 @2x2 "Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"'],
392
- ['@lightningv4 @1x1 "A serene landscape with mountains"'],
393
- ['@turbov3 @2x1 "Abstract art, colorful and vibrant"'],
394
  ["Write a Python function to check if a number is prime."],
395
  ["@tts2 What causes rainbows to form?"],
396
  ],
 
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
 
 
33
  model_id = "prithivMLmods/FastThink-0.5B-Tiny"
34
  tokenizer = AutoTokenizer.from_pretrained(model_id)
35
  model = AutoModelForCausalLM.from_pretrained(
 
39
  )
40
  model.eval()
41
 
 
42
  TTS_VOICES = [
43
  "en-US-JennyNeural", # @tts1
44
  "en-US-GuyNeural", # @tts2
 
51
  torch_dtype=torch.float16
52
  ).to("cuda").eval()
53
 
 
54
  async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
55
  communicate = edge_tts.Communicate(text, voice)
56
  await communicate.save(output_file)
 
63
  cleaned.append(msg)
64
  return cleaned
65
 
 
66
  bad_words = json.loads(os.getenv('BAD_WORDS', "[]"))
67
  bad_words_negative = json.loads(os.getenv('BAD_WORDS_NEGATIVE', "[]"))
68
  default_negative = os.getenv("default_negative", "")
 
76
  return True
77
  return False
78
 
 
79
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
80
  if randomize_seed:
81
  seed = random.randint(0, MAX_SEED)
 
87
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
88
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
89
 
 
90
  dtype = torch.float16 if device.type == "cuda" else torch.float32
91
 
 
92
  if torch.cuda.is_available():
93
  # Lightning 5 model
94
  pipe = StableDiffusionXLPipeline.from_pretrained(
 
161
  ).to(device)
162
  print("Running on CPU; models loaded in float32.")
163
 
 
164
  DEFAULT_MODEL = "Lightning 5"
165
  MODEL_CHOICES = [DEFAULT_MODEL, "Lightning 4", "Turbo v3"]
166
  models = {
 
169
  "Turbo v3": pipe3
170
  }
171
 
172
+ def save_image(img: Image.Image) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  unique_name = str(uuid.uuid4()) + ".png"
174
+ img.save(unique_name)
175
+ return unique_name
176
 
 
 
 
177
  @spaces.GPU
178
  def generate(
179
  input_dict: dict,
 
202
  elif "@turbov3" in lower_text:
203
  model_choice = "Turbo v3"
204
 
205
+ # Remove the model flag from the prompt.
 
 
 
 
206
  prompt_clean = re.sub(r"@lightningv5", "", text, flags=re.IGNORECASE)
207
  prompt_clean = re.sub(r"@lightningv4", "", prompt_clean, flags=re.IGNORECASE)
208
  prompt_clean = re.sub(r"@turbov3", "", prompt_clean, flags=re.IGNORECASE)
 
209
  prompt_clean = prompt_clean.strip().strip('"')
210
 
211
+ # Default parameters for single image generation.
212
  width = 1024
213
  height = 1024
214
  guidance_scale = 6.0
215
  seed_val = 0
216
+ randomize_seed_flag = True
217
+
218
+ seed_val = int(randomize_seed_fn(seed_val, randomize_seed_flag))
219
+ generator = torch.Generator(device=device).manual_seed(seed_val)
220
 
221
+ options = {
222
+ "prompt": prompt_clean,
223
+ "negative_prompt": default_negative,
224
+ "width": width,
225
+ "height": height,
226
+ "guidance_scale": guidance_scale,
227
+ "num_inference_steps": 30,
228
+ "generator": generator,
229
+ "num_images_per_prompt": 1,
230
+ "use_resolution_binning": True,
231
+ "output_type": "pil",
232
+ }
233
+ if device.type == "cuda":
234
+ torch.cuda.empty_cache()
235
+
236
+ selected_pipe = models.get(model_choice, pipe)
237
+ images = selected_pipe(**options).images
238
+ image_path = save_image(images[0])
239
+ yield "Generating image..."
240
+ yield gr.Image(image_path)
241
  return
242
 
243
  # Otherwise, handle text/chat (and TTS) generation.
 
311
  output_file = asyncio.run(text_to_speech(final_response, voice))
312
  yield gr.Audio(output_file, autoplay=True)
313
 
 
314
  DESCRIPTION = """
315
  # IMAGINEO 4K ⚡
316
  """
 
340
  ],
341
  examples=[
342
  ["@tts1 Who is Nikola Tesla, and why did he die?"],
343
+ ['@lightningv5 "Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"'],
344
+ ['@lightningv4 "A serene landscape with mountains"'],
345
+ ['@turbov3 "Abstract art, colorful and vibrant"'],
346
  ["Write a Python function to check if a number is prime."],
347
  ["@tts2 What causes rainbows to form?"],
348
  ],