SivaResearch committed on
Commit
28a48bb
·
verified ·
1 Parent(s): 6db3844

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -28,19 +28,24 @@ SYSTEM_PROMPT = """<s>[INST] <<SYS>>
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
-
32
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True, system_prompt="System: "):
    """Serialize a list of chat messages into a single flat prompt string.

    Args:
        messages: iterable of dicts with keys "role" ("system" | "user" |
            "assistant") and "content" (a str, or a list of str which is
            joined with newlines).
        bos: beginning-of-sequence token prepended when ``add_bos`` is True.
        eos: end-of-sequence token appended after each assistant turn.
        add_bos: whether to prefix the result with ``bos``.
        system_prompt: literal prefix placed before system-message content.

    Returns:
        The concatenated prompt string.

    Raises:
        ValueError: if a message carries any other role.
    """
    formatted_text = ""
    for message in messages:
        content = message["content"]
        # Accept multi-part content uniformly: a list of strings is joined
        # with newlines before formatting (fixes TypeError on list payloads).
        if isinstance(content, list):
            content = "\n".join(content)
        role = message["role"]
        if role == "system":
            formatted_text += system_prompt + content + "\n"
        elif role == "user":
            formatted_text += "\n" + content + "\n"
        elif role == "assistant":
            # Assistant turns are stripped and terminated with the EOS token.
            formatted_text += "\n" + content.strip() + eos + "\n"
        else:
            raise ValueError(
                "Chat template only supports 'system', 'user', and 'assistant' roles. Invalid role: {}.".format(
                    role
                )
            )
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text
50
 
 
51
  def inference(input_prompts, model, tokenizer, system_prompt="System: "):
52
  output_texts = []
53
  model = model.to(device) # Move the model to the same device as the input data
 
28
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
 
31
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True, system_prompt="System: "):
    """Serialize a list of chat messages into a single flat prompt string.

    Args:
        messages: iterable of dicts with keys "role" ("system" | "user" |
            "assistant") and "content" (a str, or a list of str which is
            joined with newlines).
        bos: beginning-of-sequence token prepended when ``add_bos`` is True.
        eos: end-of-sequence token appended after each assistant turn.
        add_bos: whether to prefix the result with ``bos``.
        system_prompt: literal prefix placed before system-message content.

    Returns:
        The concatenated prompt string.

    Raises:
        ValueError: if a message carries any other role.
    """
    formatted_text = ""
    for message in messages:
        content = message["content"]
        # Normalize once instead of duplicating the isinstance check in each
        # branch; this also makes the system branch accept list content,
        # consistent with the user/assistant branches.
        if isinstance(content, list):
            content = "\n".join(content)
        role = message["role"]
        if role == "system":
            formatted_text += system_prompt + content + "\n"
        elif role == "user":
            formatted_text += "\n" + content + "\n"
        elif role == "assistant":
            # Assistant turns are stripped and terminated with the EOS token.
            formatted_text += "\n" + content.strip() + eos + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user', and 'assistant' roles. Invalid role: {}.".format(
                    role
                )
            )
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text
55
 
56
+
57
  def inference(input_prompts, model, tokenizer, system_prompt="System: "):
58
  output_texts = []
59
  model = model.to(device) # Move the model to the same device as the input data