@@ -1,3 +1,4 @@
+import asyncio
 import copy
 import warnings
 from functools import partial
@@ -13,6 +14,7 @@ from litellm.exceptions import (
     APIConnectionError,
     ContentPolicyViolationError,
     InternalServerError,
+    OpenAIError,
     RateLimitError,
     ServiceUnavailableError,
 )
@@ -24,6 +26,7 @@ from tenacity import (
     wait_random_exponential,
 )

+from opendevin.core.exceptions import UserCancelledError
 from opendevin.core.logger import llm_prompt_logger, llm_response_logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.metrics import Metrics
@@ -56,6 +59,9 @@ class LLM:
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True

+        # Set up config attributes with default values to prevent AttributeError
+        LLMConfig.set_missing_attributes(self.config)
+
         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
@@ -66,11 +72,11 @@ class LLM:
                 self.config.model.split(':')[0]
             )
         # noinspection PyBroadException
-        except Exception:
-            logger.warning(f'Could not get model info for {config.model}')
+        except Exception as e:
+            logger.warning(f'Could not get model info for {config.model}:\n{e}')

         # Set the max tokens in an LM-specific way if not set
-        if config.max_input_tokens is None:
+        if self.config.max_input_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_input_tokens' in self.model_info
@@ -81,7 +87,7 @@ class LLM:
                 # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                 self.config.max_input_tokens = 4096

-        if config.max_output_tokens is None:
+        if self.config.max_output_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_output_tokens' in self.model_info
@@ -119,11 +125,11 @@ class LLM:

         @retry(
             reraise=True,
-            stop=stop_after_attempt(config.num_retries),
+            stop=stop_after_attempt(self.config.num_retries),
             wait=wait_random_exponential(
-                multiplier=config.retry_multiplier,
-                min=config.retry_min_wait,
-                max=config.retry_max_wait,
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
             ),
             retry=retry_if_exception_type(
                 (
@@ -147,11 +153,15 @@ class LLM:
             # log the prompt
             debug_message = ''
             for message in messages:
-                debug_message += message_separator + message['content']
+                if message['content'].strip():
+                    debug_message += message_separator + message['content']
             llm_prompt_logger.debug(debug_message)

-            # call the completion function
-            resp = completion_unwrapped(*args, **kwargs)
+            # skip if messages is empty (thus debug_message is empty)
+            if debug_message:
+                resp = completion_unwrapped(*args, **kwargs)
+            else:
+                resp = {'choices': [{'message': {'content': ''}}]}

             # log the response
             message_back = resp['choices'][0]['message']['content']
@@ -159,10 +169,191 @@ class LLM:

             # post-process to log costs
             self._post_completion(resp)
+
             return resp

         self._completion = wrapper  # type: ignore

+        # Async version
+        self._async_completion = partial(
+            self._call_acompletion,
+            model=self.config.model,
+            api_key=self.config.api_key,
+            base_url=self.config.base_url,
+            api_version=self.config.api_version,
+            custom_llm_provider=self.config.custom_llm_provider,
+            max_tokens=self.config.max_output_tokens,
+            timeout=self.config.timeout,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
+            drop_params=True,
+        )
+
+        async_completion_unwrapped = self._async_completion
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.config.num_retries),
+            wait=wait_random_exponential(
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
+            ),
+            retry=retry_if_exception_type(
+                (
+                    RateLimitError,
+                    APIConnectionError,
+                    ServiceUnavailableError,
+                    InternalServerError,
+                    ContentPolicyViolationError,
+                )
+            ),
+            after=attempt_on_error,
+        )
+        async def async_completion_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion function."""
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt
+            debug_message = ''
+            for message in messages:
+                debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            async def check_stopped():
+                while True:
+                    if (
+                        hasattr(self.config, 'on_cancel_requested_fn')
+                        and self.config.on_cancel_requested_fn is not None
+                        and await self.config.on_cancel_requested_fn()
+                    ):
+                        raise UserCancelledError('LLM request cancelled by user')
+                    await asyncio.sleep(0.1)
+
+            stop_check_task = asyncio.create_task(check_stopped())
+
+            try:
+                # Directly call and await litellm_acompletion
+                resp = await async_completion_unwrapped(*args, **kwargs)
+
+                # skip if messages is empty (thus debug_message is empty)
+                if debug_message:
+                    message_back = resp['choices'][0]['message']['content']
+                    llm_response_logger.debug(message_back)
+                else:
+                    resp = {'choices': [{'message': {'content': ''}}]}
+                self._post_completion(resp)
+
+                # We do not support streaming in this method, thus return resp
+                return resp
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                await asyncio.sleep(0.1)
+                stop_check_task.cancel()
+                try:
+                    await stop_check_task
+                except asyncio.CancelledError:
+                    pass
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.config.num_retries),
+            wait=wait_random_exponential(
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
+            ),
+            retry=retry_if_exception_type(
+                (
+                    RateLimitError,
+                    APIConnectionError,
+                    ServiceUnavailableError,
+                    InternalServerError,
+                    ContentPolicyViolationError,
+                )
+            ),
+            after=attempt_on_error,
+        )
+        async def async_acompletion_stream_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion with streaming function."""
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt
+            debug_message = ''
+            for message in messages:
+                debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            try:
+                # Directly call and await litellm_acompletion
+                resp = await async_completion_unwrapped(*args, **kwargs)
+
+                # For streaming we iterate over the chunks
+                async for chunk in resp:
+                    # Check for cancellation before yielding the chunk
+                    if (
+                        hasattr(self.config, 'on_cancel_requested_fn')
+                        and self.config.on_cancel_requested_fn is not None
+                        and await self.config.on_cancel_requested_fn()
+                    ):
+                        raise UserCancelledError(
+                            'LLM request cancelled due to CANCELLED state'
+                        )
+                    # with streaming, it is "delta", not "message"!
+                    message_back = chunk['choices'][0]['delta']['content']
+                    llm_response_logger.debug(message_back)
+                    self._post_completion(chunk)
+
+                    yield chunk
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                if kwargs.get('stream', False):
+                    await asyncio.sleep(0.1)
+
+        self._async_completion = async_completion_wrapper  # type: ignore
+        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
+
+    async def _call_acompletion(self, *args, **kwargs):
+        return await litellm.acompletion(*args, **kwargs)
+
     @property
     def completion(self):
         """Decorator for the litellm completion function.
@@ -171,6 +362,22 @@ class LLM:
         """
         return self._completion

+    @property
+    def async_completion(self):
+        """Decorator for the async litellm acompletion function.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_completion
+
+    @property
+    def async_streaming_completion(self):
+        """Decorator for the async litellm acompletion function with streaming.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_streaming_completion
+
     def _post_completion(self, response: str) -> None:
         """Post-process the completion response."""
         try:
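
For reviewers, a minimal usage sketch of the new async entry points added above (illustrative only: `demo` and the `llm` instance are hypothetical, and agent wiring, config setup, and error handling are omitted):

import asyncio

async def demo(llm):
    # `llm` is assumed to be an already-constructed instance of the LLM class patched above
    messages = [{'role': 'user', 'content': 'Hello!'}]

    # non-streaming async call
    resp = await llm.async_completion(messages=messages)
    print(resp['choices'][0]['message']['content'])

    # streaming async call: chunks carry 'delta' rather than 'message'
    async for chunk in llm.async_streaming_completion(messages=messages, stream=True):
        print(chunk['choices'][0]['delta']['content'] or '', end='')

# asyncio.run(demo(llm)) once an `llm` object is available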