1inkusFace committed on
Commit
413bf66
·
verified ·
1 Parent(s): d0a5976

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -279,37 +279,37 @@ def expand_prompt(prompt):
279
  "Rephrase this scene to have more elaborate details: "
280
  )
281
  input_text = f"{system_prompt_rewrite} {user_prompt_rewrite} {prompt}"
282
- input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {prompt}"
283
  print("-- got prompt --")
284
  # Encode the input text and include the attention mask
285
  encoded_inputs = txt_tokenizer(input_text, return_tensors="pt", return_attention_mask=True).to("cuda:0")
286
- encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
287
  # Ensure all values are on the correct device
288
  input_ids = encoded_inputs["input_ids"].to("cuda:0")
289
- input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
290
  attention_mask = encoded_inputs["attention_mask"].to("cuda:0")
291
- attention_mask_2 = encoded_inputs_2["attention_mask"].to("cuda:0")
292
  print("-- tokenize prompt --")
293
  # Google T5
294
  #input_ids = txt_tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
295
  outputs = model.generate(
296
  input_ids=input_ids,
297
  attention_mask=attention_mask,
298
- max_new_tokens=512,
299
  temperature=0.2,
300
  top_p=0.9,
301
  do_sample=True,
302
  )
 
 
 
 
 
303
  outputs_2 = model.generate(
304
  input_ids=input_ids_2,
305
  attention_mask=attention_mask_2,
306
- max_new_tokens=65,
307
  temperature=0.2,
308
  top_p=0.9,
309
  do_sample=True,
310
  )
311
  # Use the encoded tensor 'text_inputs' here
312
- enhanced_prompt = txt_tokenizer.decode(outputs[0], skip_special_tokens=True)
313
  enhanced_prompt_2 = txt_tokenizer.decode(outputs_2[0], skip_special_tokens=True)
314
  print('-- generated prompt --')
315
  enhanced_prompt = filter_text(enhanced_prompt,prompt)
@@ -404,7 +404,7 @@ def generate_30(
404
 
405
  expand_prompt(prompt)
406
  expand_prompt(caption)
407
- expand_prompt(caption_2)
408
 
409
  print('-- generating image --')
410
  sd_image = ip_model.generate(
@@ -414,6 +414,7 @@ def generate_30(
414
  pil_image_4=sd_image_d,
415
  pil_image_5=sd_image_e,
416
  prompt=prompt,
 
417
  negative_prompt=negative_prompt,
418
  text_scale=text_scale,
419
  ip_scale=ip_scale,
 
279
  "Rephrase this scene to have more elaborate details: "
280
  )
281
  input_text = f"{system_prompt_rewrite} {user_prompt_rewrite} {prompt}"
 
282
  print("-- got prompt --")
283
  # Encode the input text and include the attention mask
284
  encoded_inputs = txt_tokenizer(input_text, return_tensors="pt", return_attention_mask=True).to("cuda:0")
 
285
  # Ensure all values are on the correct device
286
  input_ids = encoded_inputs["input_ids"].to("cuda:0")
 
287
  attention_mask = encoded_inputs["attention_mask"].to("cuda:0")
 
288
  print("-- tokenize prompt --")
289
  # Google T5
290
  #input_ids = txt_tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
291
  outputs = model.generate(
292
  input_ids=input_ids,
293
  attention_mask=attention_mask,
294
+ max_new_tokens=1024,
295
  temperature=0.2,
296
  top_p=0.9,
297
  do_sample=True,
298
  )
299
+ enhanced_prompt = txt_tokenizer.decode(outputs[0], skip_special_tokens=True)
300
+ input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {enhanced_prompt}"
301
+ encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
302
+ input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
303
+ attention_mask_2 = encoded_inputs_2["attention_mask"].to("cuda:0")
304
  outputs_2 = model.generate(
305
  input_ids=input_ids_2,
306
  attention_mask=attention_mask_2,
307
+ max_new_tokens=1024,
308
  temperature=0.2,
309
  top_p=0.9,
310
  do_sample=True,
311
  )
312
  # Use the encoded tensor 'text_inputs' here
 
313
  enhanced_prompt_2 = txt_tokenizer.decode(outputs_2[0], skip_special_tokens=True)
314
  print('-- generated prompt --')
315
  enhanced_prompt = filter_text(enhanced_prompt,prompt)
 
404
 
405
  expand_prompt(prompt)
406
  expand_prompt(caption)
407
+ expanded = expand_prompt(caption_2)
408
 
409
  print('-- generating image --')
410
  sd_image = ip_model.generate(
 
414
  pil_image_4=sd_image_d,
415
  pil_image_5=sd_image_e,
416
  prompt=prompt,
417
+ prompt_2=expanded,
418
  negative_prompt=negative_prompt,
419
  text_scale=text_scale,
420
  ip_scale=ip_scale,