Added Salesforce endpoint

gen_api_answer.py  (+51, -11)  CHANGED
@@ -23,7 +23,7 @@ together_client = Together()
 hf_api_key = os.getenv("HF_API_KEY")
 flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
 cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
-
+salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
 def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
     """Get response from OpenAI API"""
     try:
@@ -195,6 +195,36 @@ def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, m
     except Exception as e:
         return f"Error with Cohere model {model_name}: {str(e)}"
 
+def get_salesforce_response(model_name, prompt, system_prompt=None, max_tokens=2048, temperature=0):
+    """Get response from Salesforce Research API"""
+    try:
+        headers = {
+            'accept': 'application/json',
+            "content-type": "application/json",
+            "X-Api-Key": salesforce_api_key,
+        }
+
+        # Create messages list
+        messages = []
+        messages.append({"role": "user", "content": prompt})
+
+        json_data = {
+            "prompts": messages,
+            "temperature": temperature,
+            "top_p": 1,
+            "max_tokens": max_tokens,
+        }
+
+        response = requests.post(
+            'https://gateway.salesforceresearch.ai/sfr-judge/process',
+            headers=headers,
+            json=json_data
+        )
+        response.raise_for_status()
+        return response.json()['result'][0]
+    except Exception as e:
+        return f"Error with Salesforce model {model_name}: {str(e)}"
+
 def get_model_response(
     model_name,
     model_info,
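Note: the new get_salesforce_response helper posts a single user turn to the sfr-judge/process gateway and returns response.json()['result'][0]. As written, model_name is only used in the error message, and the system_prompt argument is accepted but never added to the payload. A minimal smoke-test sketch, assuming SALESFORCE_API_KEY is exported and the module is importable (the model label and prompt text below are illustrative, not from the repo):

    from gen_api_answer import get_salesforce_response

    reply = get_salesforce_response(
        "sfr-judge",  # placeholder label; only surfaces in error strings
        "Rate the response 'Paris is the capital of France.' on a 1-5 scale.",
        max_tokens=512,
        temperature=0,
    )
    print(reply)  # first element of the gateway's 'result' list, or an error string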
@@ -210,24 +240,25 @@ def get_model_response(
     api_model = model_info["api_model"]
     organization = model_info["organization"]
 
-    # Determine if model is Prometheus
+    # Determine if model is Prometheus, Atla, Flow Judge, or Salesforce
     is_prometheus = (organization == "Prometheus")
     is_atla = (organization == "Atla")
     is_flow_judge = (organization == "Flow AI")
-
-
+    is_salesforce = (organization == "Salesforce")
+
+    # For non-Prometheus/Atla/Flow Judge/Salesforce models, use the Judge system prompt
+    system_prompt = None if (is_prometheus or is_atla or is_flow_judge or is_salesforce) else JUDGE_SYSTEM_PROMPT
 
     # Select the appropriate base prompt
-
-    if is_atla:
+    if is_atla or is_salesforce:  # Use same prompt for Atla and Salesforce
         base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
     elif is_flow_judge:
         base_prompt = FLOW_JUDGE_PROMPT
     else:
         base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT
 
-    # For non-Prometheus/non-Atla models, replace the
-    if not (is_prometheus or is_atla or is_flow_judge):
+    # For non-Prometheus/non-Atla/non-Salesforce models, use Prometheus but replace the output format with JSON
+    if not (is_prometheus or is_atla or is_flow_judge or is_salesforce):
         base_prompt = base_prompt.replace(
             '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
             '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
@@ -247,7 +278,6 @@ def get_model_response(
             score4_desc=prompt_data['score4_desc'],
             score5_desc=prompt_data['score5_desc']
         )
-
     else:
         human_input = f"<user_input>\n{prompt_data['human_input']}\n</user_input>"
         ai_response = f"<response>\n{prompt_data['ai_response']}\n</response>"
@@ -300,8 +330,13 @@ def get_model_response(
         )
     elif organization == "Flow AI":
         return get_flow_judge_response(
-            api_model, final_prompt
+            api_model, final_prompt
         )
+    elif organization == "Salesforce":
+        response = get_salesforce_response(
+            api_model, final_prompt, system_prompt, max_tokens, temperature
+        )
+        return response
     else:
         # All other organizations use Together API
         return get_together_response(
@@ -324,7 +359,12 @@ def parse_model_response(response):
         data = json.loads(response)
         return str(data.get("result", "N/A")), data.get("feedback", "N/A")
     except json.JSONDecodeError:
-        # If that fails
+        # If that fails, check if this is a Salesforce response (which uses ATLA format)
+        if "**Reasoning:**" in response or "**Result:**" in response:
+            # Use ATLA parser for Salesforce responses
+            return atla_parse_model_response(response)
+
+        # Otherwise try to find JSON within the response
         json_match = re.search(r"{.*}", response, re.DOTALL)
         if json_match:
             data = json.loads(json_match.group(0))
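Note: the JSONDecodeError fallback above hands Salesforce outputs to the existing atla_parse_model_response helper, which is defined elsewhere in the file and not shown in this diff. A sketch of the kind of extraction such a parser would perform on the "**Reasoning:** ... **Result:** ..." layout, assuming that format (the repo's actual implementation may differ):

    import re

    def atla_style_parse(response: str):
        # Illustrative stand-in for atla_parse_model_response, not the repo's code.
        reasoning = re.search(r"\*\*Reasoning:\*\*\s*(.*?)(?=\*\*Result:\*\*|$)",
                              response, re.DOTALL)
        result = re.search(r"\*\*Result:\*\*\s*(\d+)", response)
        score = result.group(1) if result else "N/A"
        feedback = reasoning.group(1).strip() if reasoning else "N/A"
        return score, feedback  # same (score, feedback) order as parse_model_response

    print(atla_style_parse("**Reasoning:** Clear and correct.\n**Result:** 5"))
    # -> ('5', 'Clear and correct.')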
|