1inkusFace commited on
Commit
f40a3f9
·
verified ·
1 Parent(s): 6fbeb7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -280,8 +280,8 @@ def expand_prompt(prompt):
280
  input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {prompt}"
281
  print("-- got prompt --")
282
  # Encode the input text and include the attention mask
283
- encoded_inputs = txt_tokenizer(input_text, return_tensors="pt", return_attention_mask=True)
284
- encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True)
285
  # Ensure all values are on the correct device
286
  input_ids = encoded_inputs["input_ids"].to("cuda:0")
287
  input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")
 
280
  input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {prompt}"
281
  print("-- got prompt --")
282
  # Encode the input text and include the attention mask
283
+ encoded_inputs = txt_tokenizer(input_text, return_tensors="pt", return_attention_mask=True).to("cuda:0")
284
+ encoded_inputs_2 = txt_tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True).to("cuda:0")
285
  # Ensure all values are on the correct device
286
  input_ids = encoded_inputs["input_ids"].to("cuda:0")
287
  input_ids_2 = encoded_inputs_2["input_ids"].to("cuda:0")