Spaces:
Sleeping
Sleeping
lukestanley
commited on
Commit
·
1e622b4
1
Parent(s):
135f3ac
Add retry logic upon schema fail for Mistral API calls
Browse files
utils.py
CHANGED
@@ -190,7 +190,7 @@ def llm_stream_serverless(prompt,model):
|
|
190 |
LAST_REQUEST_TIME = None
|
191 |
REQUEST_INTERVAL = 0.5 # Minimum time interval between requests in seconds
|
192 |
|
193 |
-
def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict[str, Any]]:
|
194 |
global LAST_REQUEST_TIME
|
195 |
current_time = time()
|
196 |
if LAST_REQUEST_TIME is not None:
|
@@ -227,10 +227,24 @@ def llm_stream_mistral_api(prompt: str, pydantic_model_class) -> Union[str, Dict
|
|
227 |
print(result)
|
228 |
output = result['choices'][0]['message']['content']
|
229 |
if pydantic_model_class:
|
230 |
-
|
231 |
-
|
232 |
-
#
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
else:
|
235 |
print("No pydantic model class provided, returning without class validation")
|
236 |
return json.loads(output)
|
|
|
190 |
LAST_REQUEST_TIME = None
|
191 |
REQUEST_INTERVAL = 0.5 # Minimum time interval between requests in seconds
|
192 |
|
193 |
+
def llm_stream_mistral_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]:
|
194 |
global LAST_REQUEST_TIME
|
195 |
current_time = time()
|
196 |
if LAST_REQUEST_TIME is not None:
|
|
|
227 |
print(result)
|
228 |
output = result['choices'][0]['message']['content']
|
229 |
if pydantic_model_class:
|
230 |
+
# TODO: Use more robust error handling that works for all cases without retrying?
|
231 |
+
# Maybe APIs that dont have grammar should be avoided?
|
232 |
+
# Investigate grammar enforcement with open ended generations?
|
233 |
+
try:
|
234 |
+
parsed_result = pydantic_model_class.model_validate_json(output)
|
235 |
+
print(parsed_result)
|
236 |
+
# This will raise an exception if the model is invalid,
|
237 |
+
except Exception as e:
|
238 |
+
print(f"Error validating pydantic model: {e}")
|
239 |
+
# Let's retry by calling ourselves again if attempts < 3
|
240 |
+
if attempts == 0:
|
241 |
+
# We modify the prompt to remind it to output JSON in the required format
|
242 |
+
prompt = f"{prompt} You must output the JSON in the required format!"
|
243 |
+
if attempts < 3:
|
244 |
+
attempts += 1
|
245 |
+
print(f"Retrying Mistral API call, attempt {attempts}")
|
246 |
+
return llm_stream_mistral_api(prompt, pydantic_model_class, attempts)
|
247 |
+
|
248 |
else:
|
249 |
print("No pydantic model class provided, returning without class validation")
|
250 |
return json.loads(output)
|