Commit 13b3516 (verified) · committed by 1inkusFace · Parent(s): 04af224

Update app.py

Files changed (1): app.py +26 -47
app.py CHANGED
@@ -106,8 +106,9 @@ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str
         negative = ""
     return p.replace("{prompt}", positive), n + negative
 
+unetX = UNet2DConditionModel.from_pretrained('ford442/RealVisXL_V5.0_BF16', subfolder='unet', low_cpu_mem_usage=False, token=True) #.to(device).to(torch.bfloat16) #.to(device=device, dtype=torch.bfloat16)
+
 def load_and_prepare_model():
-    unetX = UNet2DConditionModel.from_pretrained('ford442/RealVisXL_V5.0_BF16', subfolder='unet', low_cpu_mem_usage=False, token=True) #.to(device).to(torch.bfloat16) #.to(device=device, dtype=torch.bfloat16)
     vaeX = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", safety_checker=None, use_safetensors=False, low_cpu_mem_usage=False, torch_dtype=torch.float32, token=True) #.to(device).to(torch.bfloat16) #.to(device=device, dtype=torch.bfloat16)
     pipe = StableDiffusionXLPipeline.from_pretrained(
         'ford442/RealVisXL_V5.0_BF16',
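
Note: this hunk hoists the UNet load out of load_and_prepare_model() to module scope, so the weights are read once at import and every generate_* call reuses the same object (see the global unetX declarations added further down). A minimal sketch of the pattern, using a stand-in nn.Module rather than the real UNet2DConditionModel:

import torch
import torch.nn as nn

# Loaded once at import time (module scope), like unetX in app.py.
shared_model = nn.Linear(4, 4)

def generate():
    # Each call reuses the module-level object instead of reloading it.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = shared_model.to(device=device, dtype=torch.bfloat16)
    x = torch.ones(1, 4, dtype=torch.bfloat16, device=device)
    return model(x)

print(generate().shape)  # torch.Size([1, 4])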
@@ -248,14 +249,14 @@ def captioning(img):
     output_prompt=[]
     # Initial caption generation without a prompt:
     inputsa = processor5(images=img, return_tensors="pt").to('cuda')
-    generated_ids = model5.generate(**inputsa, min_length=42, max_length=128)
+    generated_ids = model5.generate(**inputsa, min_length=42, max_length=64)
     generated_text = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
     output_prompt.append(generated_text)
     print(generated_text)
     # Loop through prompts array:
     for prompt in prompts_array:
         inputs = processor5(images=img, text=prompt, return_tensors="pt").to('cuda')
-        generated_ids = model5.generate(**inputs, min_length=32, max_length=64) # Adjust max_length if needed
+        generated_ids = model5.generate(**inputs, min_length=32, max_length=42) # Adjust max_length if needed
         generated_text = processor5.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
         response_text = generated_text.replace(prompt, "").strip() #Or could try .split(prompt, 1)[-1].strip()
         output_prompt.append(response_text)
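
Note: both caps shrink here (128 to 64 for the unprompted caption, 64 to 42 inside the loop). In transformers, min_length/max_length bound the length of the returned sequence in tokens, so this roughly halves each caption before the strings are merged into one prompt. A hedged sketch of the same generate contract on a small text model (the BLIP processor5/model5 pair takes the identical kwargs; expand_prompt below uses max_new_tokens instead, which counts only newly generated tokens):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("A photo of", return_tensors="pt")
# min_length/max_length bound the whole sequence in tokens, mirroring
# model5.generate(**inputs, min_length=32, max_length=42) in app.py.
ids = model.generate(**inputs, min_length=32, max_length=42,
                     do_sample=False, pad_token_id=tok.eos_token_id)
print(tok.batch_decode(ids, skip_special_tokens=True)[0].strip())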
@@ -296,7 +297,7 @@ def expand_prompt(prompt):
     outputs = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        max_new_tokens=256,
+        max_new_tokens=384,
         temperature=0.2,
         top_p=0.9,
         do_sample=True,
@@ -304,12 +305,12 @@ def expand_prompt(prompt):
     enhanced_prompt = txt_tokenizer.decode(outputs[0], skip_special_tokens=True)
     print('-- generated prompt 1 --')
     print(enhanced_prompt)
-
     enhanced_prompt = filter_text(enhanced_prompt,prompt)
     enhanced_prompt = filter_text(enhanced_prompt,user_prompt_rewrite)
     enhanced_prompt = filter_text(enhanced_prompt,system_prompt_rewrite)
     print('-- filtered prompt --')
     print(enhanced_prompt)
+    '''
     input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {enhanced_prompt}"
     encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
     input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
@@ -332,7 +333,8 @@ def expand_prompt(prompt):
     print('-- filtered prompt 2 --')
     print(enhanced_prompt_2)
     enh_prompt=[enhanced_prompt,enhanced_prompt_2]
-    return enh_prompt
+    '''
+    return enhanced_prompt
 
 @spaces.GPU(duration=40)
 def generate_30(
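
Note: the paired ''' markers (the one added at new line 313 above and this one) turn the entire second rewrite pass into a bare string literal, parsed but with no effect, so it stays easy to re-enable; and expand_prompt now returns a single string instead of a two-element list, which is why every caller below drops the expanded[0]/expanded[1] indexing. A minimal sketch of the contract change, with toy bodies standing in for the real generation code:

def expand_prompt_old(prompt):
    enhanced_prompt = prompt + " (pass 1)"
    enhanced_prompt_2 = enhanced_prompt + " (pass 2)"
    return [enhanced_prompt, enhanced_prompt_2]  # old: list of two strings

def expand_prompt_new(prompt):
    enhanced_prompt = prompt + " (pass 1)"
    '''
    The second pass now sits inside a triple-quoted string: still in the
    file, never run. Re-enabling it is a two-line deletion.
    enhanced_prompt_2 = ...
    '''
    return enhanced_prompt  # new: single string

# Call sites simplify to match:
expanded = expand_prompt_new("a cat")
new_prompt = "a cat" + ' ' + expanded  # no expanded[0]/expanded[1]
print(new_prompt)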
@@ -416,16 +418,10 @@ def generate_30(
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     filename= f'rv_IP_{timestamp}.png'
     print("-- using image file --")
-
-    caption = flatten_and_stringify(caption)
-    caption = " ".join(caption)
-
-    caption_2 = flatten_and_stringify(caption_2)
-    caption_2 = " ".join(caption_2)
-
-    print(caption)
-    print(caption_2)
-
+    captions = caption+caption_2
+    captions = flatten_and_stringify(captions)
+    captions = " ".join(captions)
+    print(captions)
     print("-- generating further caption --")
     global model5
     global processor5
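
Note: captioning() builds a list of strings, so caption+caption_2 is list concatenation, and the two separate flatten/join passes collapse into one. flatten_and_stringify is defined elsewhere in app.py; a plausible stand-in (an assumption, not the file's actual helper) to make the shape of the data concrete:

def flatten_and_stringify(data):
    # Hypothetical stand-in for app.py's helper: flatten arbitrarily
    # nested lists/tuples and coerce every leaf to str.
    if isinstance(data, (list, tuple)):
        return [s for item in data for s in flatten_and_stringify(item)]
    return [str(data)]

caption = ["a city street at night", "neon signs reflecting in rain"]
caption_2 = ["shallow depth of field"]

captions = caption + caption_2             # list concatenation
captions = flatten_and_stringify(captions) # one flat list of strings
captions = " ".join(captions)              # one prompt string
print(captions)

(Observe that expand_prompt is still called with caption a few lines below, not with the merged captions value, so the merged string is currently only printed.)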
@@ -435,9 +431,7 @@ def generate_30(
     gc.collect()
     torch.cuda.empty_cache()
     expanded = expand_prompt(caption)
-    expanded_1 = expanded[0]
-    expanded_2 = expanded[1]
-    new_prompt = prompt+' '+expanded_1+' '+expanded_2
+    new_prompt = prompt+' '+expanded
     print("-- ------------ --")
     print("-- FINAL PROMPT --")
     print(new_prompt)
@@ -451,6 +445,7 @@ def generate_30(
     torch.cuda.empty_cache()
     global text_encoder_1
     global text_encoder_2
+    global unetX
     pipe.text_encoder=text_encoder_1.to(device=device, dtype=torch.bfloat16)
     pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
     pipe.unet=unetX.to(device=device, dtype=torch.bfloat16)
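
Note: with unetX now at module scope, each generate_* declares global unetX before pipe.unet=unetX.to(...). Strictly, global is only required to rebind the name (reading a module-level object works without it), so the declaration mainly documents the shared state. Also, nn.Module.to() mutates the module in place and returns the same object, so the dtype/device move persists across calls. A small sketch of both points (the same three-line change recurs verbatim in generate_60 and generate_90 below):

import torch
import torch.nn as nn

shared = nn.Linear(2, 2)  # module scope, like unetX

def use_shared():
    global shared            # required only if we rebind `shared`
    m = shared.to(dtype=torch.bfloat16)  # nn.Module.to() works in place
    assert m is shared       # same object, so the cast sticks
    return m

use_shared()
print(shared.weight.dtype)   # torch.bfloat16: state carried across calls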
@@ -573,17 +568,10 @@ def generate_60(
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     filename= f'rv_IP_{timestamp}.png'
     print("-- using image file --")
-
-
-    caption = flatten_and_stringify(caption)
-    caption = " ".join(caption)
-
-    caption_2 = flatten_and_stringify(caption_2)
-    caption_2 = " ".join(caption_2)
-
-    print(caption)
-    print(caption_2)
-
+    captions = caption+caption_2
+    captions = flatten_and_stringify(captions)
+    captions = " ".join(captions)
+    print(captions)
     print("-- generating further caption --")
     global model5
     global processor5
@@ -593,9 +581,7 @@ def generate_60(
     gc.collect()
     torch.cuda.empty_cache()
     expanded = expand_prompt(caption)
-    expanded_1 = expanded[0]
-    expanded_2 = expanded[1]
-    new_prompt = prompt+' '+expanded_1+' '+expanded_2
+    new_prompt = prompt+' '+expanded
     print("-- ------------ --")
     print("-- FINAL PROMPT --")
     print(new_prompt)
@@ -609,6 +595,7 @@ def generate_60(
     torch.cuda.empty_cache()
     global text_encoder_1
     global text_encoder_2
+    global unetX
     pipe.text_encoder=text_encoder_1.to(device=device, dtype=torch.bfloat16)
     pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
     pipe.unet=unetX.to(device=device, dtype=torch.bfloat16)
@@ -731,17 +718,10 @@ def generate_90(
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     filename= f'rv_IP_{timestamp}.png'
     print("-- using image file --")
-
-
-    caption = flatten_and_stringify(caption)
-    caption = " ".join(caption)
-
-    caption_2 = flatten_and_stringify(caption_2)
-    caption_2 = " ".join(caption_2)
-
-    print(caption)
-    print(caption_2)
-
+    captions = caption+caption_2
+    captions = flatten_and_stringify(captions)
+    captions = " ".join(captions)
+    print(captions)
     print("-- generating further caption --")
     global model5
     global processor5
@@ -751,9 +731,7 @@ def generate_90(
     gc.collect()
     torch.cuda.empty_cache()
     expanded = expand_prompt(caption)
-    expanded_1 = expanded[0]
-    expanded_2 = expanded[1]
-    new_prompt = prompt+' '+expanded_1+' '+expanded_2
+    new_prompt = prompt+' '+expanded
     print("-- ------------ --")
     print("-- FINAL PROMPT --")
     print(new_prompt)
@@ -767,6 +745,7 @@ def generate_90(
     torch.cuda.empty_cache()
     global text_encoder_1
     global text_encoder_2
+    global unetX
     pipe.text_encoder=text_encoder_1.to(device=device, dtype=torch.bfloat16)
     pipe.text_encoder_2=text_encoder_2.to(device=device, dtype=torch.bfloat16)
     pipe.unet=unetX.to(device=device, dtype=torch.bfloat16)
 