|
import asyncio
|
|
from functools import partial
|
|
from typing import Any
|
|
|
|
from litellm import acompletion as litellm_acompletion
|
|
|
|
from openhands.core.exceptions import UserCancelledError
|
|
from openhands.core.logger import openhands_logger as logger
|
|
from openhands.llm.llm import (
|
|
LLM,
|
|
LLM_RETRY_EXCEPTIONS,
|
|
REASONING_EFFORT_SUPPORTED_MODELS,
|
|
)
|
|
from openhands.utils.shutdown_listener import should_continue
|
|
|
|
|
|
class AsyncLLM(LLM):
    """Asynchronous LLM class.

    Extends ``LLM`` with ``async_completion``: an awaitable wrapper around
    litellm's ``acompletion`` that is pre-bound to this instance's
    configuration, wrapped in the retry decorator, logs the prompt and the
    response, and calls ``_post_completion`` on the result.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base ``LLM`` and build the async completion callable.

        All positional and keyword arguments are forwarded unchanged to
        ``LLM.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Bind the per-instance configuration once so callers only have to
        # supply the request-specific arguments (typically ``messages``).
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key.get_secret_value()
            if self.config.api_key
            else None,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        # Keep a handle on the undecorated partial: the wrapper below calls
        # it, and ``self._async_completion`` is rebound to the wrapper at the
        # end of ``__init__``.
        async_completion_unwrapped = self._async_completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Wrapper for the litellm acompletion function that adds logging and cost tracking.

            Accepts either ``(model, messages, **kwargs)`` positionally or
            ``messages=...`` as a keyword argument.

            Returns:
                The response object produced by the underlying completion
                call (non-streaming).

            Raises:
                ValueError: If no messages were supplied.
                UserCancelledError: Re-raised unchanged after a debug log.
                Exception: Any other error is logged and re-raised.
            """
            messages: list[dict[str, Any]] | dict[str, Any] = []

            # Some callers pass the model and messages positionally, i.e.
            # completion(model, messages, **kwargs).
            if len(args) > 1:
                messages = args[1]
                kwargs['messages'] = messages

                # Drop the first two positionals; messages now travel in
                # kwargs and the model is already bound in the partial.
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # Pass the configured reasoning effort for models that support it.
            if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS:
                kwargs['reasoning_effort'] = self.config.reasoning_effort

            # Normalise to a list so a single message dict is handled the
            # same way as a list of messages.
            messages = messages if isinstance(messages, list) else [messages]

            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            async def check_stopped():
                """Poll every 0.1s until shutdown or a cancel request is seen."""
                while should_continue():
                    # Looked up on every iteration so a callback installed or
                    # removed on the config mid-request is honoured.
                    on_cancel_requested = getattr(
                        self.config, 'on_cancel_requested_fn', None
                    )
                    if on_cancel_requested is not None and await on_cancel_requested():
                        return
                    await asyncio.sleep(0.1)

            # NOTE(review): the outcome of this task is never inspected by
            # the request path below, so a cancel request observed here does
            # not by itself interrupt the in-flight completion call.
            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await the bound completion callable.
                resp = await async_completion_unwrapped(*args, **kwargs)

                message_back = resp['choices'][0]['message']['content']
                self.log_response(message_back)

                self._post_completion(resp)

                # Streaming is not handled here; return the full response.
                return resp

            except UserCancelledError:
                logger.debug('LLM request cancelled by user.')
                raise
            except Exception as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise

            finally:
                # Presumably leaves the stop checker one more poll interval
                # before it is torn down - TODO confirm intent.
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    # Expected: we cancelled the task ourselves just above.
                    pass

        self._async_completion = async_completion_wrapper

    async def _call_acompletion(self, *args, **kwargs):
        """Wrapper for the litellm acompletion function.

        Forwards every argument unchanged and returns the awaited result.
        """
        return await litellm_acompletion(*args, **kwargs)

    @property
    def async_completion(self):
        """Return the retry-decorated async completion callable built in ``__init__``."""
        return self._async_completion
|
|
|