kaikaidai committed · verified
Commit 0f79b0c · Parent: d5ec495

Apply chat template for Atla responses
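
This commit routes the Prometheus and Atla judge prompts through each model's own chat template before posting them to the Hugging Face endpoint, instead of sending the raw prompt string. A minimal sketch of the templating step (the message content is illustrative; the model id is the one used in the diff below):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("AtlaAI/Atla-8B-preview")
    messages = [{"role": "user", "content": "Evaluate the response against the rubric..."}]

    # Renders the messages into the model's chat markup as a plain string,
    # which is then sent as the "inputs" field of the endpoint payload.
    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

Note that apply_chat_template accepts add_generation_prompt directly, as shown here; the diff below instead places that flag in the endpoint's "parameters" payload.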

Files changed (1): gen_api_answer.py (+95 -57)
gen_api_answer.py CHANGED
@@ -12,19 +12,17 @@ from prompts import (
     PROMETHEUS_PROMPT_WITH_REFERENCE,
     ATLA_PROMPT,
     ATLA_PROMPT_WITH_REFERENCE,
-    FLOW_JUDGE_PROMPT
+    FLOW_JUDGE_PROMPT
 )
+from transformers import AutoTokenizer
 
 # Initialize clients
 anthropic_client = anthropic.Anthropic()
 openai_client = OpenAI()
 together_client = Together()
 hf_api_key = os.getenv("HF_API_KEY")
-cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
-
-
 flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
-
+cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
 
 def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
     """Get response from OpenAI API"""
@@ -73,7 +71,7 @@ def get_together_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT,
     except Exception as e:
         return f"Error with Together model {model_name}: {str(e)}"
 
-def get_prometheus_response(model_name, prompt, max_tokens=500, temperature=0.01): # temperature needs to be > 0 for hf to work
+def get_prometheus_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
     """Get response from Hugging Face model"""
     try:
         headers = {
@@ -82,8 +80,19 @@ def get_prometheus_response(model_name, prompt, max_tokens=500, temperature=0.01
             "Content-Type": "application/json"
         }
 
+        # Create messages list for chat template
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        # Apply chat template
+        model_id = "prometheus-eval/prometheus-7b-v2.0"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+
         payload = {
-            "inputs": prompt,
+            "inputs": formatted_prompt,
             "parameters": {
                 "max_new_tokens": max_tokens,
                 "return_full_text": False,
@@ -100,7 +109,7 @@ def get_prometheus_response(model_name, prompt, max_tokens=500, temperature=0.01
     except Exception as e:
         return f"Error with Hugging Face model {model_name}: {str(e)}"
 
-def get_atla_response(model_name, prompt, max_tokens=500, temperature=0.01):
+def get_atla_response(model_name, prompt, system_prompt=None, max_tokens=500, temperature=0.01):
     """Get response from HF endpoint for Atla model"""
     try:
         headers = {
@@ -109,13 +118,25 @@ def get_atla_response(model_name, prompt, max_tokens=500, temperature=0.01):
             "Content-Type": "application/json"
         }
 
+        # Create messages list for chat template
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        # Apply chat template
+        model_id = "AtlaAI/Atla-8B-preview"  # Update this if using a different model
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
+
         payload = {
-            "inputs": prompt,
+            "inputs": formatted_prompt,
             "parameters": {
                 "max_new_tokens": max_tokens,
                 "return_full_text": False,
                 "temperature": temperature,
-                "seed": 42
+                "seed": 42,
+                "add_generation_prompt": True
             }
         }
 
@@ -128,27 +149,6 @@ def get_atla_response(model_name, prompt, max_tokens=500, temperature=0.01):
     except Exception as e:
         return f"Error with Atla model {model_name}: {str(e)}"
 
-def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
-    """Get response from Cohere API"""
-    try:
-        response = cohere_client.chat(
-            model=model_name,
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": prompt}
-            ],
-            max_tokens=max_tokens,
-            temperature=temperature
-        )
-        # Extract the text from the content items
-        content_items = response.message.content
-        if isinstance(content_items, list):
-            # Get the text from the first content item
-            return content_items[0].text
-        return str(content_items) # Fallback if it's not a list
-    except Exception as e:
-        return f"Error with Cohere model {model_name}: {str(e)}"
-
 def get_flow_judge_response(model_name, prompt, max_tokens=500, temperature=0.1, top_p=0.95) -> str:
     """Get response from Flow Judge"""
     try:
@@ -173,6 +173,27 @@ def get_flow_judge_response(model_name, prompt, max_tokens=500, temperature=0.1,
     except Exception as e:
         return f"Error with Flow Judge completions model {model_name}: {str(e)}"
 
+def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
+    """Get response from Cohere API"""
+    try:
+        response = cohere_client.chat(
+            model=model_name,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": prompt}
+            ],
+            max_tokens=max_tokens,
+            temperature=temperature
+        )
+        # Extract the text from the content items
+        content_items = response.message.content
+        if isinstance(content_items, list):
+            # Get the text from the first content item
+            return content_items[0].text
+        return str(content_items) # Fallback if it's not a list
+    except Exception as e:
+        return f"Error with Cohere model {model_name}: {str(e)}"
+
 def get_model_response(
     model_name,
     model_info,
@@ -188,21 +209,22 @@ def get_model_response(
     api_model = model_info["api_model"]
     organization = model_info["organization"]
 
-    # Determine if model is Prometheus or Atla
+    # Determine if model is Prometheus or Atla or Flow Judge
     is_prometheus = (organization == "Prometheus")
     is_atla = (organization == "Atla")
     is_flow_judge = (organization == "Flow AI")
-    # For non-Prometheus/Atla models, use the Judge system prompt
+    # For non-Prometheus/Atla models/Flow Judge, use the Judge system prompt
     system_prompt = None if (is_prometheus or is_atla or is_flow_judge) else JUDGE_SYSTEM_PROMPT
 
     # Select the appropriate base prompt
+
     if is_atla:
         base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
     elif is_flow_judge:
         base_prompt = FLOW_JUDGE_PROMPT
     else:
         base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT
-
+
     # For non-Prometheus/non-Atla models, replace the specific instruction
     if not (is_prometheus or is_atla or is_flow_judge):
         base_prompt = base_prompt.replace(
@@ -224,6 +246,7 @@ def get_model_response(
                 score4_desc=prompt_data['score4_desc'],
                 score5_desc=prompt_data['score5_desc']
             )
+
         else:
             human_input = f"<user_input>\n{prompt_data['human_input']}\n</user_input>"
             ai_response = f"<response>\n{prompt_data['ai_response']}\n</response>"
@@ -249,6 +272,7 @@ def get_model_response(
                 EVALUATION_CRITERIA=eval_criteria,
                 RUBRIC=rubric
             )
+
     except KeyError as e:
         return f"Error formatting prompt: Missing required field {str(e)}"
 
@@ -263,11 +287,11 @@ def get_model_response(
         )
     elif organization == "Prometheus":
         return get_prometheus_response(
-            api_model, final_prompt, max_tokens, temperature = 0.01
+            api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
         )
     elif organization == "Atla":
         return get_atla_response(
-            api_model, final_prompt, max_tokens, temperature = 0.01
+            api_model, final_prompt, system_prompt, max_tokens, temperature = 0.01
         )
     elif organization == "Cohere":
         return get_cohere_response(
@@ -290,6 +314,10 @@ def parse_model_response(response):
     # Debug print
     print(f"Raw model response: {response}")
 
+    # If response is already a dictionary, use it directly
+    if isinstance(response, dict):
+        return str(response.get("result", "N/A")), response.get("feedback", "N/A")
+
     # First try to parse the entire response as JSON
     try:
         data = json.loads(response)
@@ -306,6 +334,16 @@ def parse_model_response(response):
     except Exception as e:
         # Debug print for error case
         print(f"Failed to parse response: {str(e)}")
+
+        # If the error message itself contains valid JSON, try to parse that
+        try:
+            error_json_match = re.search(r"{.*}", str(e), re.DOTALL)
+            if error_json_match:
+                data = json.loads(error_json_match.group(0))
+                return str(data.get("result", "N/A")), data.get("feedback", "N/A")
+        except:
+            pass
+
         return "Error", f"Failed to parse response: {response}"
 
 def prometheus_parse_model_response(output):
@@ -363,6 +401,27 @@ def prometheus_parse_model_response(output):
     except Exception as e:
         print(f"Failed to parse response: {str(e)}")
         return "Error", f"Exception during parsing: {str(e)}"
+
+def atla_parse_model_response(output):
+    """Parse response from ATLA model"""
+    try:
+        print(f"Raw Atla model response: {output}")
+        output = output.strip()
+
+        # Look for the Reasoning and Result sections
+        reasoning_match = re.search(r'\*\*Reasoning:\*\*(.*?)(?=\*\*Result:|$)', output, re.DOTALL)
+        result_match = re.search(r'\*\*Result:\*\*\s*(\d+)', output)
+
+        if reasoning_match and result_match:
+            feedback = reasoning_match.group(1).strip()
+            score = result_match.group(1)
+            return str(score), feedback
+
+        return "Error", f"Failed to parse ATLA response format: {output}"
+
+    except Exception as e:
+        print(f"Failed to parse ATLA response: {str(e)}")
+        return "Error", f"Exception during parsing: {str(e)}"
 
 def flow_judge_parse_model_response(output):
     try:
@@ -386,25 +445,4 @@ def flow_judge_parse_model_response(output):
 
     except Exception as e:
         print(f"Failed to parse response: {str(e)}")
-        return "Error", f"Exception during parsing: {str(e)}"
-
-def atla_parse_model_response(output):
-    """Parse response from ATLA model"""
-    try:
-        print(f"Raw Atla model response: {output}")
-        output = output.strip()
-
-        # Look for the Reasoning and Result sections
-        reasoning_match = re.search(r'\*\*Reasoning:\*\*(.*?)(?=\*\*Result:|$)', output, re.DOTALL)
-        result_match = re.search(r'\*\*Result:\*\*\s*(\d+)', output)
-
-        if reasoning_match and result_match:
-            feedback = reasoning_match.group(1).strip()
-            score = result_match.group(1)
-            return str(score), feedback
-
-        return "Error", f"Failed to parse ATLA response format: {output}"
-
-    except Exception as e:
-        print(f"Failed to parse ATLA response: {str(e)}")
         return "Error", f"Exception during parsing: {str(e)}"
 