# openhands/llm/async_llm.py
# (provenance: uploaded by ar08, "Upload 1040 files", commit 246d201 verified)
import asyncio
from functools import partial
from typing import Any
from litellm import acompletion as litellm_acompletion
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
from openhands.llm.llm import (
LLM,
LLM_RETRY_EXCEPTIONS,
REASONING_EFFORT_SUPPORTED_MODELS,
)
from openhands.utils.shutdown_listener import should_continue
class AsyncLLM(LLM):
    """Asynchronous LLM class.

    Wraps litellm's async ``acompletion`` with retry logic (inherited
    ``retry_decorator``), prompt/response logging, cost/token tracking via
    ``_post_completion``, and cooperative cancellation through an optional
    ``config.on_cancel_requested_fn`` callback polled in a background task.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Pre-bind the static, config-derived completion parameters so
        # callers only need to supply messages (plus per-call overrides).
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key.get_secret_value()
            if self.config.api_key
            else None,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        # Keep a reference to the un-retried callable; the wrapper below
        # closes over it so the retry decorator wraps exactly one layer.
        async_completion_unwrapped = self._async_completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Wrapper for the litellm acompletion function that adds logging and cost tracking."""
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            # see llm.py for more details
            if len(args) > 1:
                # args[0] is the model, args[1] the messages; move messages
                # into kwargs and drop both positionals. (The original code
                # had a dead ternary here — `else args[0]` was unreachable
                # under `len(args) > 1`.)
                messages = args[1]
                kwargs['messages'] = messages
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # Set reasoning effort for models that support it
            if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS:
                kwargs['reasoning_effort'] = self.config.reasoning_effort

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            async def check_stopped():
                # Poll the optional user-cancel callback every 100ms until
                # it fires or the process begins shutting down.
                while should_continue():
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        return
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                message_back = resp['choices'][0]['message']['content']
                self.log_response(message_back)

                # log costs and tokens used
                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp
            except UserCancelledError:
                logger.debug('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                # Give the poller one last cycle (presumably so a pending
                # cancel check can complete — TODO confirm intent), then
                # cancel it and swallow the expected CancelledError.
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass

        self._async_completion = async_completion_wrapper  # type: ignore

    async def _call_acompletion(self, *args, **kwargs):
        """Thin passthrough to litellm's ``acompletion``.

        Kept as a separate method — presumably a seam so tests can patch
        the network call (original note: "Used in testing?").
        """
        return await litellm_acompletion(*args, **kwargs)

    @property
    def async_completion(self):
        """The retry-wrapped async completion callable built in ``__init__``."""
        return self._async_completion