import copy
import os
import time
import warnings
from functools import partial
from typing import Any
import requests
from openhands.core.config import LLMConfig
with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
APIConnectionError,
APIError,
InternalServerError,
RateLimitError,
ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.utils import create_pretrained_tokenizer
from openhands.core.exceptions import CloudFlareBlockageError
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.fn_call_converter import (
STOP_WORDS,
convert_fncall_messages_to_non_fncall_messages,
convert_non_fncall_messages_to_fncall_messages,
)
from openhands.llm.metrics import Metrics
from openhands.llm.retry_mixin import RetryMixin
__all__ = ['LLM']
# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
APIConnectionError,
# FIXME: APIError is useful on 502 from a proxy for example,
# but it also retries on other errors that are permanent
APIError,
InternalServerError,
RateLimitError,
ServiceUnavailableError,
)
# models that support prompt caching
# remove this once gemini and deepseek prompt caching are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
'claude-3-5-sonnet-20241022',
'claude-3-5-sonnet-20240620',
'claude-3-5-haiku-20241022',
'claude-3-haiku-20240307',
'claude-3-opus-20240229',
]
# function calling supporting models
FUNCTION_CALLING_SUPPORTED_MODELS = [
'claude-3-5-sonnet',
'claude-3-5-sonnet-20240620',
'claude-3-5-sonnet-20241022',
'claude-3.5-haiku',
'claude-3-5-haiku-20241022',
'gpt-4o-mini',
'gpt-4o',
'o1-2024-12-17',
]
REASONING_EFFORT_SUPPORTED_MODELS = [
'o1-2024-12-17',
]
MODELS_WITHOUT_STOP_WORDS = [
'o1-mini',
]
class LLM(RetryMixin, DebugMixin):
"""The LLM class represents a Language Model instance.
Attributes:
config: an LLMConfig object specifying the configuration of the LLM.
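Example (illustrative sketch; assumes only LLMConfig.model needs to be set and
that credentials are available via the environment or config):
    config = LLMConfig(model='gpt-4o')
    llm = LLM(config)
    response = llm.completion(messages=[{'role': 'user', 'content': 'Hello'}])
    print(response['choices'][0]['message']['content'])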
"""
def __init__(
self,
config: LLMConfig,
metrics: Metrics | None = None,
):
"""Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
Passing simple parameters always overrides config.
Args:
config: The LLM configuration.
metrics: The metrics to use.
"""
self._tried_model_info = False
self.metrics: Metrics = (
metrics if metrics is not None else Metrics(model_name=config.model)
)
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
self.model_info: ModelInfo | None = None
if self.config.log_completions:
if self.config.log_completions_folder is None:
raise RuntimeError(
'log_completions_folder is required when log_completions is enabled'
)
os.makedirs(self.config.log_completions_folder, exist_ok=True)
# call init_model_info to initialize config.max_output_tokens
# which is used in the completion partial below
with warnings.catch_warnings():
warnings.simplefilter('ignore')
self.init_model_info()
if self.vision_is_active():
logger.debug('LLM: model has vision enabled')
if self.is_caching_prompt_active():
logger.debug('LLM: caching prompt enabled')
if self.is_function_calling_active():
logger.debug('LLM: model supports function calling')
# if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
if self.config.custom_tokenizer is not None:
self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
else:
self.tokenizer = None
# set up the completion function
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
max_tokens=self.config.max_output_tokens,
timeout=self.config.timeout,
temperature=self.config.temperature,
top_p=self.config.top_p,
drop_params=self.config.drop_params,
)
self._completion_unwrapped = self._completion
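# The retry decorator below (from RetryMixin) retries the wrapped call on
# LLM_RETRY_EXCEPTIONS, with a backoff bounded by retry_min_wait/retry_max_wait
# seconds and scaled by retry_multiplier, for at most num_retries attempts.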
@self.retry_decorator(
num_retries=self.config.num_retries,
retry_exceptions=LLM_RETRY_EXCEPTIONS,
retry_min_wait=self.config.retry_min_wait,
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.core.utils import json
messages: list[dict[str, Any]] | dict[str, Any] = []
mock_function_calling = kwargs.pop('mock_function_calling', False)
# some callers might send the model and messages directly
# litellm allows positional args, like completion(model, messages, **kwargs)
if len(args) > 1:
# ignore the first argument if it's provided (it would be the model)
# design wise: we don't allow overriding the configured values
# implementation wise: the partial function already sets the model as a kwarg,
# along with the other configured kwargs
messages = args[1] if len(args) > 1 else args[0]
kwargs['messages'] = messages
# drop the first two positional args: the model is already set by the partial and messages are now in kwargs
args = args[2:]
elif 'messages' in kwargs:
messages = kwargs['messages']
# ensure we work with a list of messages
messages = messages if isinstance(messages, list) else [messages]
original_fncall_messages = copy.deepcopy(messages)
mock_fncall_tools = None
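# When the model lacks native tool calling, the tool schemas are rendered into
# the prompt as plain text here, and the model's text reply is converted back
# into tool calls after the completion returns (see below).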
if mock_function_calling:
assert (
'tools' in kwargs
), "'tools' must be in kwargs when mock_function_calling is True"
messages = convert_fncall_messages_to_non_fncall_messages(
messages, kwargs['tools']
)
kwargs['messages'] = messages
if self.config.model not in MODELS_WITHOUT_STOP_WORDS:
kwargs['stop'] = STOP_WORDS
mock_fncall_tools = kwargs.pop('tools')
# if we have no messages, something went very wrong
if not messages:
raise ValueError(
'The messages list is empty. At least one message is required.'
)
# log the entire LLM prompt
self.log_prompt(messages)
if self.is_caching_prompt_active():
# Anthropic-specific prompt caching
if 'claude-3' in self.config.model:
kwargs['extra_headers'] = {
'anthropic-beta': 'prompt-caching-2024-07-31',
}
# Set reasoning effort for models that support it
if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS:
kwargs['reasoning_effort'] = self.config.reasoning_effort
# set litellm modify_params to the configured value
# True by default to allow litellm to do transformations, like adding a default message when a message is empty
# NOTE: this setting is global; unlike drop_params, it cannot be overridden in the litellm completion partial
litellm.modify_params = self.config.modify_params
try:
# Record start time for latency measurement
start_time = time.time()
# we don't support streaming here, thus we get a ModelResponse
resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)
# Calculate and record latency
latency = time.time() - start_time
response_id = resp.get('id', 'unknown')
self.metrics.add_response_latency(latency, response_id)
non_fncall_response = copy.deepcopy(resp)
if mock_function_calling:
assert len(resp.choices) == 1
assert mock_fncall_tools is not None
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [non_fncall_response_message], mock_fncall_tools
)
)
fn_call_response_message = fn_call_messages_with_response[-1]
if not isinstance(fn_call_response_message, LiteLLMMessage):
fn_call_response_message = LiteLLMMessage(
**fn_call_response_message
)
resp.choices[0].message = fn_call_response_message
message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
'message'
].get('tool_calls', [])
if tool_calls:
for tool_call in tool_calls:
fn_name = tool_call.function.name
fn_args = tool_call.function.arguments
message_back += f'\nFunction call: {fn_name}({fn_args})'
# log the LLM response
self.log_response(message_back)
# post-process the response first to calculate cost
cost = self._post_completion(resp)
# log for evals or other scripts that need the raw completion
if self.config.log_completions:
assert self.config.log_completions_folder is not None
log_file = os.path.join(
self.config.log_completions_folder,
# use the model name from metrics (it may differ from config.model, e.g. for the draft editor LLM)
f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
)
# set up the dict to be logged
_d = {
'messages': messages,
'response': resp,
'args': args,
'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
'timestamp': time.time(),
'cost': cost,
}
# if non-native function calling, save messages/response separately
if mock_function_calling:
# Overwrite response as non-fncall to be consistent with messages
_d['response'] = non_fncall_response
# Save fncall_messages/response separately
_d['fncall_messages'] = original_fncall_messages
_d['fncall_response'] = resp
with open(log_file, 'w') as f:
f.write(json.dumps(_d))
return resp
except APIError as e:
if 'Attention Required! | Cloudflare' in str(e):
raise CloudFlareBlockageError(
'Request blocked by CloudFlare'
) from e
raise
self._completion = wrapper
@property
def completion(self):
"""Decorator for the litellm completion function.
Check the complete documentation at https://litellm.vercel.app/docs/completion
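Note: this returns the wrapper built in __init__, which adds retries, prompt/response
logging, cost and latency tracking, and optional mocked function calling on top of
litellm's completion.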
"""
return self._completion
def init_model_info(self):
if self._tried_model_info:
return
self._tried_model_info = True
try:
if self.config.model.startswith('openrouter'):
self.model_info = litellm.get_model_info(self.config.model)
except Exception as e:
logger.debug(f'Error getting model info: {e}')
if self.config.model.startswith('litellm_proxy/'):
# If we're using a LiteLLM proxy, fetch model info from the proxy's
# GET {base_url}/v1/model/info endpoint and match on the model name
response = requests.get(
f'{self.config.base_url}/v1/model/info',
headers={
'Authorization': f'Bearer {self.config.api_key.get_secret_value() if self.config.api_key else None}'
},
)
resp_json = response.json()
if 'data' not in resp_json:
logger.error(
f'Error getting model info from LiteLLM proxy: {resp_json}'
)
all_model_info = resp_json.get('data', [])
current_model_info = next(
(
info
for info in all_model_info
if info['model_name']
== self.config.model.removeprefix('litellm_proxy/')
),
None,
)
if current_model_info:
self.model_info = current_model_info['model_info']
# Last two attempts: look up model info by variations of the model name
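# e.g. (illustrative) 'ollama/llama3:8b' is looked up as 'ollama/llama3',
# and 'litellm_proxy/anthropic/claude-3-5-sonnet' as 'claude-3-5-sonnet'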
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split(':')[0]
)
# noinspection PyBroadException
except Exception:
pass
if not self.model_info:
try:
self.model_info = litellm.get_model_info(
self.config.model.split('/')[-1]
)
# noinspection PyBroadException
except Exception:
pass
from openhands.core.utils import json
logger.debug(f'Model info: {json.dumps(self.model_info, indent=2)}')
if self.config.model.startswith('huggingface'):
# HF doesn't support the OpenAI default value for top_p (1)
logger.debug(
f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
)
self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p
# Set the max tokens in an LM-specific way if not set
if self.config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
and isinstance(self.model_info['max_input_tokens'], int)
):
self.config.max_input_tokens = self.model_info['max_input_tokens']
else:
# Safe fallback for any potentially viable model
self.config.max_input_tokens = 4096
if self.config.max_output_tokens is None:
# Safe default for any potentially viable model
self.config.max_output_tokens = 4096
if self.model_info is not None:
# Prefer max_output_tokens when present; otherwise fall back to max_tokens.
# litellm model entries may define both, one, or neither of these two parameters.
if 'max_output_tokens' in self.model_info and isinstance(
self.model_info['max_output_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_output_tokens']
elif 'max_tokens' in self.model_info and isinstance(
self.model_info['max_tokens'], int
):
self.config.max_output_tokens = self.model_info['max_tokens']
def vision_is_active(self) -> bool:
with warnings.catch_warnings():
warnings.simplefilter('ignore')
return not self.config.disable_vision and self._supports_vision()
def _supports_vision(self) -> bool:
"""Acquire from litellm if model is vision capable.
Returns:
bool: True if model is vision capable. Return False if model not supported by litellm.
"""
# litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
# but model_info will have the correct value for some reason.
# we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers
# remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
# Check both the full model name and the name after proxy prefix for vision support
return (
litellm.supports_vision(self.config.model)
or litellm.supports_vision(self.config.model.split('/')[-1])
or (
self.model_info is not None
and self.model_info.get('supports_vision', False)
)
)
def is_caching_prompt_active(self) -> bool:
"""Check if prompt caching is supported and enabled for current model.
Returns:
boolean: True if prompt caching is supported and enabled for the given model.
"""
return (
self.config.caching_prompt is True
and (
self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
)
# We don't need to look up model_info, because only Anthropic models need the explicit caching breakpoint
)
def is_function_calling_active(self) -> bool:
# Check if model name is in our supported list
model_name_supported = (
self.config.model in FUNCTION_CALLING_SUPPORTED_MODELS
or self.config.model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
or any(m in self.config.model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
)
# Handle native_tool_calling user-defined configuration
if self.config.native_tool_calling is None:
return model_name_supported
elif self.config.native_tool_calling is False:
return False
else:
# try to enable native tool calling if supported by the model
supports_fn_call = litellm.supports_function_calling(
model=self.config.model
)
return supports_fn_call
def _post_completion(self, response: ModelResponse) -> float:
"""Post-process the completion response.
Logs the cost and usage stats of the completion call.
"""
try:
cur_cost = self._completion_cost(response)
except Exception:
cur_cost = 0
stats = ''
if self.cost_metric_supported:
# keep track of the cost
stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
cur_cost,
self.metrics.accumulated_cost,
)
# Add latency to stats if available
if self.metrics.response_latencies:
latest_latency = self.metrics.response_latencies[-1]
stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency
usage: Usage | None = response.get('usage')
if usage:
# keep track of the input and output tokens
input_tokens = usage.get('prompt_tokens')
output_tokens = usage.get('completion_tokens')
if input_tokens:
stats += 'Input tokens: ' + str(input_tokens)
if output_tokens:
stats += (
(' | ' if input_tokens else '')
+ 'Output tokens: '
+ str(output_tokens)
+ '\n'
)
# read the prompt cache hit, if any
prompt_tokens_details: PromptTokensDetails = usage.get(
'prompt_tokens_details'
)
cache_hit_tokens = (
prompt_tokens_details.cached_tokens if prompt_tokens_details else None
)
if cache_hit_tokens:
stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'
# For Anthropic, the cache writes have a different cost than regular input tokens
# but litellm doesn't separate them in the usage stats
# so we can read it from the provider-specific extra field
model_extra = usage.get('model_extra', {})
cache_write_tokens = model_extra.get('cache_creation_input_tokens')
if cache_write_tokens:
stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'
# log the stats
if stats:
logger.debug(stats)
return cur_cost
def get_token_count(self, messages: list[dict] | list[Message]) -> int:
"""Get the number of tokens in a list of messages. Use dicts for better token counting.
Args:
messages (list): A list of messages, either as a list of dicts or as a list of Message objects.
Returns:
int: The number of tokens.
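Example (illustrative):
    llm.get_token_count([{'role': 'user', 'content': 'Hello'}])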
"""
# attempt to convert Message objects to dicts, litellm expects dicts
if (
isinstance(messages, list)
and len(messages) > 0
and isinstance(messages[0], Message)
):
logger.info(
'Message objects now include serialized tool calls in token counting'
)
messages = self.format_messages_for_llm(messages) # type: ignore
# try to get the token count with the default litellm tokenizers
# or the custom tokenizer if set for this LLM configuration
try:
return litellm.token_counter(
model=self.config.model,
messages=messages,
custom_tokenizer=self.tokenizer,
)
except Exception as e:
# limit logspam in case token count is not supported
logger.error(
f'Error getting token count for\n model {self.config.model}\n{e}'
+ (
f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
if self.config.custom_tokenizer is not None
else ''
)
)
return 0
def _is_local(self) -> bool:
"""Determines if the system is using a locally running LLM.
Returns:
boolean: True if executing a local model.
"""
if self.config.base_url is not None:
for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
if substring in self.config.base_url:
return True
elif self.config.model is not None:
if self.config.model.startswith('ollama'):
return True
return False
def _completion_cost(self, response) -> float:
"""Calculate the cost of a completion response based on the model. Local models are treated as free.
Add the current cost into total cost in metrics.
Args:
response: A response from a model invocation.
Returns:
float: The cost of the response.
"""
if not self.cost_metric_supported:
return 0.0
extra_kwargs = {}
if (
self.config.input_cost_per_token is not None
and self.config.output_cost_per_token is not None
):
cost_per_token = CostPerToken(
input_cost_per_token=self.config.input_cost_per_token,
output_cost_per_token=self.config.output_cost_per_token,
)
logger.debug(f'Using custom cost per token: {cost_per_token}')
extra_kwargs['custom_cost_per_token'] = cost_per_token
# try to read the response cost directly from litellm's hidden params on the response
_hidden_params = getattr(response, '_hidden_params', {})
cost = _hidden_params.get('additional_headers', {}).get(
'llm_provider-x-litellm-response-cost', None
)
if cost is not None:
cost = float(cost)
logger.debug(f'Got response_cost from response: {cost}')
try:
if cost is None:
try:
cost = litellm_completion_cost(
completion_response=response, **extra_kwargs
)
except Exception as e:
logger.error(f'Error getting cost from litellm: {e}')
if cost is None:
_model_name = '/'.join(self.config.model.split('/')[1:])
cost = litellm_completion_cost(
completion_response=response, model=_model_name, **extra_kwargs
)
logger.debug(
f'Using fallback model name {_model_name} to get cost: {cost}'
)
self.metrics.add_cost(cost)
return cost
except Exception:
self.cost_metric_supported = False
logger.debug('Cost calculation not supported for this model.')
return 0.0
def __str__(self):
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
def __repr__(self):
return str(self)
def reset(self) -> None:
self.metrics.reset()
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
if isinstance(messages, Message):
messages = [messages]
# set flags to know how to serialize the messages
for message in messages:
message.cache_enabled = self.is_caching_prompt_active()
message.vision_enabled = self.vision_is_active()
message.function_calling_enabled = self.is_function_calling_active()
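# deepseek endpoints are assumed to expect plain-string message content rather
# than a list of content parts, hence the string serializer below (assumption)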
if 'deepseek' in self.config.model:
message.force_string_serializer = True
# let pydantic handle the serialization
return [message.model_dump() for message in messages]