Update app.py
app.py CHANGED
@@ -279,37 +279,37 @@ def expand_prompt(prompt):
         "Rephrase this scene to have more elaborate details: "
     )
     input_text = f"{system_prompt_rewrite} {user_prompt_rewrite} {prompt}"
-    input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {prompt}"
     print("-- got prompt --")
     # Encode the input text and include the attention mask
     encoded_inputs = txt_tokenizer(input_text, return_tensors="pt", return_attention_mask=True).to("cuda:0")
-    encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
     # Ensure all values are on the correct device
     input_ids = encoded_inputs["input_ids"].to("cuda:0")
-    input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
     attention_mask = encoded_inputs["attention_mask"].to("cuda:0")
-    attention_mask_2 = encoded_inputs_2["attention_mask"].to("cuda:0")
     print("-- tokenize prompt --")
     # Google T5
     #input_ids = txt_tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
     outputs = model.generate(
         input_ids=input_ids,
         attention_mask=attention_mask,
-        max_new_tokens=
+        max_new_tokens=1024,
         temperature=0.2,
         top_p=0.9,
         do_sample=True,
     )
+    enhanced_prompt = txt_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {enhanced_prompt}"
+    encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
+    input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
+    attention_mask_2 = encoded_inputs_2["attention_mask"].to("cuda:0")
     outputs_2 = model.generate(
         input_ids=input_ids_2,
         attention_mask=attention_mask_2,
-        max_new_tokens=
+        max_new_tokens=1024,
         temperature=0.2,
         top_p=0.9,
         do_sample=True,
     )
     # Use the encoded tensor 'text_inputs' here
-    enhanced_prompt = txt_tokenizer.decode(outputs[0], skip_special_tokens=True)
     enhanced_prompt_2 = txt_tokenizer.decode(outputs_2[0], skip_special_tokens=True)
     print('-- generated prompt --')
     enhanced_prompt = filter_text(enhanced_prompt,prompt)
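Review note: the old version built both rewrite inputs from the raw prompt and left max_new_tokens without a value, a syntax error, so this path could not have run as written. The commit completes the argument (1024 new tokens for both passes) and chains the rewrites, so the second pass elaborates the first pass's decoded output rather than the original prompt. (The stale comment about 'text_inputs' survives the edit.) A minimal sketch of the resulting flow; two_pass_expand is an invented name, and txt_tokenizer, model, and the *_rewrite strings are assumed to be the module-level objects defined elsewhere in app.py:

def two_pass_expand(prompt: str):
    # Pass 1: rewrite the raw prompt into a more detailed scene description.
    text_1 = f"{system_prompt_rewrite} {user_prompt_rewrite} {prompt}"
    enc_1 = txt_tokenizer(text_1, return_tensors="pt", return_attention_mask=True).to("cuda:0")
    out_1 = model.generate(
        input_ids=enc_1["input_ids"],
        attention_mask=enc_1["attention_mask"],
        max_new_tokens=1024,
        temperature=0.2,
        top_p=0.9,
        do_sample=True,
    )
    enhanced = txt_tokenizer.decode(out_1[0], skip_special_tokens=True)
    # Pass 2: feed the first rewrite back in, not the raw prompt.
    text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {enhanced}"
    enc_2 = txt_tokenizer(text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
    out_2 = model.generate(
        input_ids=enc_2["input_ids"],
        attention_mask=enc_2["attention_mask"],
        max_new_tokens=1024,
        temperature=0.2,
        top_p=0.9,
        do_sample=True,
    )
    enhanced_2 = txt_tokenizer.decode(out_2[0], skip_special_tokens=True)
    return enhanced, enhanced_2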
@@ -404,7 +404,7 @@ def generate_30(
 
     expand_prompt(prompt)
     expand_prompt(caption)
-    expand_prompt(caption_2)
+    expanded = expand_prompt(caption_2)
 
     print('-- generating image --')
     sd_image = ip_model.generate(
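Review note: the only change in this hunk is that the result of expand_prompt(caption_2) is now bound instead of discarded. The two preceding calls still throw their return values away, so they run purely for side effects; if their expansions are meant to reach the pipeline too, they would need the same treatment. The assignment also presumes expand_prompt returns the expanded text, and the hunk above ends before any return statement, so that is an assumption. A defensive variant of the new line, with the fallback invented for illustration:

expanded = expand_prompt(caption_2)
if not isinstance(expanded, str) or not expanded.strip():
    # expand_prompt returned None or empty text; fall back to the raw caption.
    expanded = caption_2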
@@ -414,6 +414,7 @@ def generate_30(
         pil_image_4=sd_image_d,
         pil_image_5=sd_image_e,
         prompt=prompt,
+        prompt_2=expanded,
         negative_prompt=negative_prompt,
         text_scale=text_scale,
         ip_scale=ip_scale,
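Review note: prompt_2 matches the keyword SDXL-style diffusers pipelines expose for the second text encoder's prompt (where it defaults to prompt when omitted); whether this IP-Adapter wrapper's generate accepts and forwards it the same way is not visible in the diff, and if the wrapper predates the parameter the call will raise a TypeError. A hypothetical guard using only standard-library introspection; the diff shows only part of the argument list, so kwargs below covers just the visible arguments:

import inspect

kwargs = dict(
    pil_image_4=sd_image_d,
    pil_image_5=sd_image_e,
    prompt=prompt,
    negative_prompt=negative_prompt,
    text_scale=text_scale,
    ip_scale=ip_scale,
)
# Only pass prompt_2 if this generate() actually declares it.
if "prompt_2" in inspect.signature(ip_model.generate).parameters:
    kwargs["prompt_2"] = expanded
sd_image = ip_model.generate(**kwargs)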