Spaces:
Running
Running
""" | |
This module contains functions to interact with the models. | |
""" | |
import json | |
import os | |
from typing import List, Optional, Tuple | |
import litellm | |
DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the given text without changing the language of it." # pylint: disable=line-too-long | |
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the given text from {source_lang} to {target_lang}." # pylint: disable=line-too-long | |
class ContextWindowExceededError(Exception): | |
pass | |
class Model: | |
def __init__( | |
self, | |
name: str, | |
provider: str = None, | |
api_key: str = None, | |
api_base: str = None, | |
summarize_instruction: str = None, | |
translate_instruction: str = None, | |
): | |
self.name = name | |
self.provider = provider | |
self.api_key = api_key | |
self.api_base = api_base | |
self.summarize_instruction = summarize_instruction or DEFAULT_SUMMARIZE_INSTRUCTION # pylint: disable=line-too-long | |
self.translate_instruction = translate_instruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long | |
# Returns the parsed result or raw response, and whether parsing succeeded. | |
def completion(self, | |
instruction: str, | |
prompt: str, | |
max_tokens: Optional[float] = None, | |
max_retries: int = 2) -> Tuple[str, bool]: | |
messages = [{ | |
"role": | |
"system", | |
"content": | |
instruction + """ | |
Output following this JSON format without using code blocks: | |
{"result": "your result here"}""" | |
}, { | |
"role": "user", | |
"content": prompt | |
}] | |
for attempt in range(max_retries + 1): | |
try: | |
response = litellm.completion(model=self.provider + "/" + | |
self.name if self.provider else self.name, | |
api_key=self.api_key, | |
api_base=self.api_base, | |
messages=messages, | |
max_tokens=max_tokens, | |
**self._get_completion_kwargs()) | |
json_response = response.choices[0].message.content | |
parsed_json = json.loads(json_response) | |
return parsed_json["result"], True | |
except litellm.ContextWindowExceededError as e: | |
raise ContextWindowExceededError() from e | |
except json.JSONDecodeError: | |
if attempt == max_retries: | |
return json_response, False | |
def _get_completion_kwargs(self): | |
return { | |
# Ref: https://litellm.vercel.app/docs/completion/input#optional-fields # pylint: disable=line-too-long | |
"response_format": { | |
"type": "json_object" | |
} | |
} | |
class AnthropicModel(Model): | |
def completion(self, | |
instruction: str, | |
prompt: str, | |
max_tokens: Optional[float] = None, | |
max_retries: int = 2) -> Tuple[str, bool]: | |
# Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long | |
prefix = "<result>" | |
suffix = "</result>" | |
messages = [{ | |
"role": | |
"user", | |
"content": | |
f"""{instruction} | |
Output following this format: | |
{prefix}...{suffix} | |
Text: | |
{prompt}""" | |
}, { | |
"role": "assistant", | |
"content": prefix | |
}] | |
for attempt in range(max_retries + 1): | |
try: | |
response = litellm.completion( | |
model=self.provider + "/" + | |
self.name if self.provider else self.name, | |
api_key=self.api_key, | |
api_base=self.api_base, | |
messages=messages, | |
max_tokens=max_tokens, | |
) | |
except litellm.ContextWindowExceededError as e: | |
raise ContextWindowExceededError() from e | |
result = response.choices[0].message.content | |
if result.endswith(suffix): | |
return result.removesuffix(suffix).strip(), True | |
if attempt == max_retries: | |
return result, False | |
class VertexModel(Model): | |
def __init__(self, name: str, vertex_credentials: str): | |
super().__init__(name, provider="vertex_ai") | |
self.vertex_credentials = vertex_credentials | |
def _get_completion_kwargs(self): | |
return { | |
"response_format": { | |
"type": "json_object" | |
}, | |
"vertex_credentials": self.vertex_credentials | |
} | |
supported_models: List[Model] = [ | |
Model("gpt-4o-2024-08-06"), | |
Model("gpt-4o-mini-2024-07-18"), | |
AnthropicModel("claude-3-5-sonnet-20240620"), | |
VertexModel("gemini-1.5-pro-001", | |
vertex_credentials=os.getenv("VERTEX_CREDENTIALS")), | |
VertexModel("gemini-1.5-flash-preview-0514", | |
vertex_credentials=os.getenv("VERTEX_CREDENTIALS")), | |
Model("meta-llama/Meta-Llama-3.1-8B-Instruct", provider="deepinfra"), | |
Model("meta-llama/Meta-Llama-3.1-70B-Instruct", provider="deepinfra"), | |
Model("meta-llama/Meta-Llama-3.1-405B-Instruct", provider="deepinfra"), | |
Model("Qwen/Qwen2.5-72B-Instruct", provider="deepinfra"), | |
Model("Qwen/Qwen2-72B-Instruct", provider="deepinfra"), | |
Model("google/gemma-2-9b-it", provider="deepinfra"), | |
Model("google/gemma-2-27b-it", provider="deepinfra"), | |
] | |
def check_models(models: List[Model]): | |
for model in models: | |
print(f"Checking model {model.name}...") | |
try: | |
model.completion( | |
"""Output following this JSON format without using code blocks: | |
{"result": "your result here"}""", "How are you?") | |
print(f"Model {model.name} is available.") | |
# This check is designed to verify the availability of the models | |
# without any issues. Therefore, we need to catch all exceptions. | |
except Exception as e: # pylint: disable=broad-except | |
raise RuntimeError(f"Model {model.name} is not available: {e}") from e | |