kaikaidai committed
Commit 420bcc4 · verified · 1 Parent(s): b77c18b

Added Salesforce endpoint

Files changed (1)
  1. gen_api_answer.py +51 -11
gen_api_answer.py CHANGED
@@ -23,7 +23,7 @@ together_client = Together()
 hf_api_key = os.getenv("HF_API_KEY")
 flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
 cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
-
+salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
 def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
     """Get response from OpenAI API"""
     try:
@@ -195,6 +195,36 @@ def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, m
     except Exception as e:
         return f"Error with Cohere model {model_name}: {str(e)}"

+def get_salesforce_response(model_name, prompt, system_prompt=None, max_tokens=2048, temperature=0):
+    """Get response from Salesforce Research API"""
+    try:
+        headers = {
+            'accept': 'application/json',
+            "content-type": "application/json",
+            "X-Api-Key": salesforce_api_key,
+        }
+
+        # Create messages list
+        messages = []
+        messages.append({"role": "user", "content": prompt})
+
+        json_data = {
+            "prompts": messages,
+            "temperature": temperature,
+            "top_p": 1,
+            "max_tokens": max_tokens,
+        }
+
+        response = requests.post(
+            'https://gateway.salesforceresearch.ai/sfr-judge/process',
+            headers=headers,
+            json=json_data
+        )
+        response.raise_for_status()
+        return response.json()['result'][0]
+    except Exception as e:
+        return f"Error with Salesforce model {model_name}: {str(e)}"
+
 def get_model_response(
     model_name,
     model_info,
@@ -210,24 +240,25 @@ def get_model_response(
     api_model = model_info["api_model"]
     organization = model_info["organization"]

-    # Determine if model is Prometheus or Atla or Flow Judge
+    # Determine if model is Prometheus, Atla, Flow Judge, or Salesforce
     is_prometheus = (organization == "Prometheus")
     is_atla = (organization == "Atla")
     is_flow_judge = (organization == "Flow AI")
-    # For non-Prometheus/Atla models/Flow Judge, use the Judge system prompt
-    system_prompt = None if (is_prometheus or is_atla or is_flow_judge) else JUDGE_SYSTEM_PROMPT
+    is_salesforce = (organization == "Salesforce")
+
+    # For non-Prometheus/Atla/Flow Judge/Salesforce models, use the Judge system prompt
+    system_prompt = None if (is_prometheus or is_atla or is_flow_judge or is_salesforce) else JUDGE_SYSTEM_PROMPT

     # Select the appropriate base prompt
-
-    if is_atla:
+    if is_atla or is_salesforce: # Use same prompt for Atla and Salesforce
         base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
     elif is_flow_judge:
         base_prompt = FLOW_JUDGE_PROMPT
     else:
         base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT

-    # For non-Prometheus/non-Atla models, replace the specific instruction
-    if not (is_prometheus or is_atla or is_flow_judge):
+    # For non-Prometheus/non-Atla/non-Salesforce models, use Prometheus but replace the output format with JSON
+    if not (is_prometheus or is_atla or is_flow_judge or is_salesforce):
         base_prompt = base_prompt.replace(
             '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
             '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
@@ -247,7 +278,6 @@ def get_model_response(
             score4_desc=prompt_data['score4_desc'],
             score5_desc=prompt_data['score5_desc']
         )
-
     else:
         human_input = f"<user_input>\n{prompt_data['human_input']}\n</user_input>"
         ai_response = f"<response>\n{prompt_data['ai_response']}\n</response>"
@@ -300,8 +330,13 @@ def get_model_response(
         )
     elif organization == "Flow AI":
         return get_flow_judge_response(
-            api_model, final_prompt, # Keep default hps
+            api_model, final_prompt
         )
+    elif organization == "Salesforce":
+        response = get_salesforce_response(
+            api_model, final_prompt, system_prompt, max_tokens, temperature
+        )
+        return response
     else:
         # All other organizations use Together API
         return get_together_response(
@@ -324,7 +359,12 @@ def parse_model_response(response):
         data = json.loads(response)
         return str(data.get("result", "N/A")), data.get("feedback", "N/A")
     except json.JSONDecodeError:
-        # If that fails (typically for smaller models), try to find JSON within the response
+        # If that fails, check if this is a Salesforce response (which uses ATLA format)
+        if "**Reasoning:**" in response or "**Result:**" in response:
+            # Use ATLA parser for Salesforce responses
+            return atla_parse_model_response(response)
+
+        # Otherwise try to find JSON within the response
        json_match = re.search(r"{.*}", response, re.DOTALL)
         if json_match:
            data = json.loads(json_match.group(0))
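
For orientation, a minimal usage sketch of the path this commit adds (not part of the diff). It assumes gen_api_answer.py is importable, that SALESFORCE_API_KEY is set in the environment, and that the judge prompt has already been formatted the way get_model_response does; the model name and prompt text below are placeholders.

# Hypothetical usage sketch, not part of the commit; assumes gen_api_answer.py
# is importable and SALESFORCE_API_KEY is exported.
from gen_api_answer import get_salesforce_response, parse_model_response

judge_prompt = "..."  # placeholder for a fully formatted ATLA-style judge prompt

# POSTs to the Salesforce Research gateway and returns result[0] from the JSON
# body, or an error string if the request fails.
response = get_salesforce_response("sfr-judge", judge_prompt)

# The Salesforce judge replies in the "**Reasoning:** ... **Result:** n" format,
# so parse_model_response falls through to atla_parse_model_response.
score, feedback = parse_model_response(response)
print(score, feedback)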