Просмотр исходного кода

(feat) LLM class: added acompletion and streaming + unit test (#3202)

* LLM class: added acompletion and streaming, unit test test_acompletion.py

* LLM: cleanup of self.config defaults and their use

* added set_missing_attributes to LLMConfig

* move default checker up
tobitege 1 год назад
Родитель
Сommit
a4cb880699
4 измененных файлов с 416 добавлено и 11 удалено
  1. 6 0
      opendevin/core/config.py
  2. 5 0
      opendevin/core/exceptions.py
  3. 218 11
      opendevin/llm/llm.py
  4. 187 0
      tests/unit/test_acompletion.py

+ 6 - 0
opendevin/core/config.py

@@ -111,6 +111,12 @@ class LLMConfig:
                 ret[k] = '******' if v else None
         return ret
 
+    def set_missing_attributes(self):
+        """Set any missing attributes to their default values."""
+        for field_name, field_obj in self.__dataclass_fields__.items():
+            if not hasattr(self, field_name):
+                setattr(self, field_name, field_obj.default)
+
 
 @dataclass
 class AgentConfig:

+ 5 - 0
opendevin/core/exceptions.py

@@ -67,3 +67,8 @@ class LLMNoActionError(Exception):
 class LLMResponseError(Exception):
     def __init__(self, message='Failed to retrieve action from LLM response'):
         super().__init__(message)
+
+
+class UserCancelledError(Exception):
+    def __init__(self, message='User cancelled the request'):
+        super().__init__(message)

+ 218 - 11
opendevin/llm/llm.py

@@ -1,3 +1,4 @@
+import asyncio
 import copy
 import warnings
 from functools import partial
@@ -13,6 +14,7 @@ from litellm.exceptions import (
     APIConnectionError,
     ContentPolicyViolationError,
     InternalServerError,
+    OpenAIError,
     RateLimitError,
     ServiceUnavailableError,
 )
@@ -24,6 +26,7 @@ from tenacity import (
     wait_random_exponential,
 )
 
+from opendevin.core.exceptions import UserCancelledError
 from opendevin.core.logger import llm_prompt_logger, llm_response_logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.metrics import Metrics
@@ -56,6 +59,9 @@ class LLM:
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
 
+        # Set up config attributes with default values to prevent AttributeError
+        LLMConfig.set_missing_attributes(self.config)
+
         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
@@ -66,11 +72,11 @@ class LLM:
                     self.config.model.split(':')[0]
                 )
         # noinspection PyBroadException
-        except Exception:
-            logger.warning(f'Could not get model info for {config.model}')
+        except Exception as e:
+            logger.warning(f'Could not get model info for {config.model}:\n{e}')
 
         # Set the max tokens in an LM-specific way if not set
-        if config.max_input_tokens is None:
+        if self.config.max_input_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_input_tokens' in self.model_info
@@ -81,7 +87,7 @@ class LLM:
                 # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                 self.config.max_input_tokens = 4096
 
-        if config.max_output_tokens is None:
+        if self.config.max_output_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_output_tokens' in self.model_info
@@ -119,11 +125,11 @@ class LLM:
 
         @retry(
             reraise=True,
-            stop=stop_after_attempt(config.num_retries),
+            stop=stop_after_attempt(self.config.num_retries),
             wait=wait_random_exponential(
-                multiplier=config.retry_multiplier,
-                min=config.retry_min_wait,
-                max=config.retry_max_wait,
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
             ),
             retry=retry_if_exception_type(
                 (
@@ -147,11 +153,15 @@ class LLM:
             # log the prompt
             debug_message = ''
             for message in messages:
-                debug_message += message_separator + message['content']
+                if message['content'].strip():
+                    debug_message += message_separator + message['content']
             llm_prompt_logger.debug(debug_message)
 
-            # call the completion function
-            resp = completion_unwrapped(*args, **kwargs)
+            # skip if messages is empty (thus debug_message is empty)
+            if debug_message:
+                resp = completion_unwrapped(*args, **kwargs)
+            else:
+                resp = {'choices': [{'message': {'content': ''}}]}
 
             # log the response
             message_back = resp['choices'][0]['message']['content']
@@ -159,10 +169,191 @@ class LLM:
 
             # post-process to log costs
             self._post_completion(resp)
+
             return resp
 
         self._completion = wrapper  # type: ignore
 
+        # Async version
+        self._async_completion = partial(
+            self._call_acompletion,
+            model=self.config.model,
+            api_key=self.config.api_key,
+            base_url=self.config.base_url,
+            api_version=self.config.api_version,
+            custom_llm_provider=self.config.custom_llm_provider,
+            max_tokens=self.config.max_output_tokens,
+            timeout=self.config.timeout,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
+            drop_params=True,
+        )
+
+        async_completion_unwrapped = self._async_completion
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.config.num_retries),
+            wait=wait_random_exponential(
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
+            ),
+            retry=retry_if_exception_type(
+                (
+                    RateLimitError,
+                    APIConnectionError,
+                    ServiceUnavailableError,
+                    InternalServerError,
+                    ContentPolicyViolationError,
+                )
+            ),
+            after=attempt_on_error,
+        )
+        async def async_completion_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion function."""
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt
+            debug_message = ''
+            for message in messages:
+                debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            async def check_stopped():
+                while True:
+                    if (
+                        hasattr(self.config, 'on_cancel_requested_fn')
+                        and self.config.on_cancel_requested_fn is not None
+                        and await self.config.on_cancel_requested_fn()
+                    ):
+                        raise UserCancelledError('LLM request cancelled by user')
+                    await asyncio.sleep(0.1)
+
+            stop_check_task = asyncio.create_task(check_stopped())
+
+            try:
+                # Directly call and await litellm_acompletion
+                resp = await async_completion_unwrapped(*args, **kwargs)
+
+                # skip if messages is empty (thus debug_message is empty)
+                if debug_message:
+                    message_back = resp['choices'][0]['message']['content']
+                    llm_response_logger.debug(message_back)
+                else:
+                    resp = {'choices': [{'message': {'content': ''}}]}
+                self._post_completion(resp)
+
+                # We do not support streaming in this method, thus return resp
+                return resp
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                await asyncio.sleep(0.1)
+                stop_check_task.cancel()
+                try:
+                    await stop_check_task
+                except asyncio.CancelledError:
+                    pass
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.config.num_retries),
+            wait=wait_random_exponential(
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
+            ),
+            retry=retry_if_exception_type(
+                (
+                    RateLimitError,
+                    APIConnectionError,
+                    ServiceUnavailableError,
+                    InternalServerError,
+                    ContentPolicyViolationError,
+                )
+            ),
+            after=attempt_on_error,
+        )
+        async def async_acompletion_stream_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion with streaming function."""
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt
+            debug_message = ''
+            for message in messages:
+                debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            try:
+                # Directly call and await litellm_acompletion
+                resp = await async_completion_unwrapped(*args, **kwargs)
+
+                # For streaming we iterate over the chunks
+                async for chunk in resp:
+                    # Check for cancellation before yielding the chunk
+                    if (
+                        hasattr(self.config, 'on_cancel_requested_fn')
+                        and self.config.on_cancel_requested_fn is not None
+                        and await self.config.on_cancel_requested_fn()
+                    ):
+                        raise UserCancelledError(
+                            'LLM request cancelled due to CANCELLED state'
+                        )
+                    # with streaming, it is "delta", not "message"!
+                    message_back = chunk['choices'][0]['delta']['content']
+                    llm_response_logger.debug(message_back)
+                    self._post_completion(chunk)
+
+                    yield chunk
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                if kwargs.get('stream', False):
+                    await asyncio.sleep(0.1)
+
+        self._async_completion = async_completion_wrapper  # type: ignore
+        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
+
+    async def _call_acompletion(self, *args, **kwargs):
+        return await litellm.acompletion(*args, **kwargs)
+
     @property
     def completion(self):
         """Decorator for the litellm completion function.
@@ -171,6 +362,22 @@ class LLM:
         """
         return self._completion
 
+    @property
+    def async_completion(self):
+        """Decorator for the async litellm acompletion function.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_completion
+
+    @property
+    def async_streaming_completion(self):
+        """Decorator for the async litellm acompletion function with streaming.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_streaming_completion
+
     def _post_completion(self, response: str) -> None:
         """Post-process the completion response."""
         try:

+ 187 - 0
tests/unit/test_acompletion.py

@@ -0,0 +1,187 @@
+import asyncio
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from opendevin.core.config import load_app_config
+from opendevin.core.exceptions import UserCancelledError
+from opendevin.llm.llm import LLM
+
+config = load_app_config()
+
+
+@pytest.fixture
+def test_llm():
+    # Create a mock config for testing
+    return LLM(config=config.get_llm_config())
+
+
+@pytest.fixture
+def mock_response():
+    return [
+        {'choices': [{'delta': {'content': 'This is a'}}]},
+        {'choices': [{'delta': {'content': ' test'}}]},
+        {'choices': [{'delta': {'content': ' message.'}}]},
+        {'choices': [{'delta': {'content': ' It is'}}]},
+        {'choices': [{'delta': {'content': ' a bit'}}]},
+        {'choices': [{'delta': {'content': ' longer'}}]},
+        {'choices': [{'delta': {'content': ' than'}}]},
+        {'choices': [{'delta': {'content': ' the'}}]},
+        {'choices': [{'delta': {'content': ' previous'}}]},
+        {'choices': [{'delta': {'content': ' one,'}}]},
+        {'choices': [{'delta': {'content': ' but'}}]},
+        {'choices': [{'delta': {'content': ' hopefully'}}]},
+        {'choices': [{'delta': {'content': ' still'}}]},
+        {'choices': [{'delta': {'content': ' short'}}]},
+        {'choices': [{'delta': {'content': ' enough.'}}]},
+    ]
+
+
+@pytest.mark.asyncio
+async def test_acompletion_non_streaming():
+    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
+        mock_response = {
+            'choices': [{'message': {'content': 'This is a test message.'}}]
+        }
+        mock_call_acompletion.return_value = mock_response
+        test_llm = LLM(config=config.get_llm_config())
+        response = await test_llm.async_completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+            drop_params=True,
+        )
+        # Assertions for non-streaming completion
+        assert response['choices'][0]['message']['content'] != ''
+
+
+@pytest.mark.asyncio
+async def test_acompletion_streaming(mock_response):
+    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
+        mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
+        test_llm = LLM(config=config.get_llm_config())
+        async for chunk in test_llm.async_streaming_completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
+        ):
+            print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
+            # Assertions for streaming completion
+            assert chunk['choices'][0]['delta']['content'] in [
+                r['choices'][0]['delta']['content'] for r in mock_response
+            ]
+
+
+@pytest.mark.asyncio
+async def test_completion(test_llm):
+    with patch.object(LLM, 'completion') as mock_completion:
+        mock_completion.return_value = {
+            'choices': [{'message': {'content': 'This is a test message.'}}]
+        }
+        response = test_llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
+        assert response['choices'][0]['message']['content'] == 'This is a test message.'
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('cancel_delay', [0.1, 0.3, 0.5, 0.7, 0.9])
+async def test_async_completion_with_user_cancellation(cancel_delay):
+    cancel_event = asyncio.Event()
+
+    async def mock_on_cancel_requested():
+        is_set = cancel_event.is_set()
+        print(f'Cancel requested: {is_set}')
+        return is_set
+
+    config = load_app_config()
+    config.on_cancel_requested_fn = mock_on_cancel_requested
+
+    async def mock_acompletion(*args, **kwargs):
+        print('Starting mock_acompletion')
+        for i in range(20):  # Increased iterations for longer running task
+            print(f'mock_acompletion iteration {i}')
+            await asyncio.sleep(0.1)
+            if await mock_on_cancel_requested():
+                print('Cancellation detected in mock_acompletion')
+                raise UserCancelledError('LLM request cancelled by user')
+        print('Completing mock_acompletion without cancellation')
+        return {'choices': [{'message': {'content': 'This is a test message.'}}]}
+
+    with patch.object(
+        LLM, '_call_acompletion', new_callable=AsyncMock
+    ) as mock_call_acompletion:
+        mock_call_acompletion.side_effect = mock_acompletion
+        test_llm = LLM(config=config.get_llm_config())
+
+        async def cancel_after_delay():
+            print(f'Starting cancel_after_delay with delay {cancel_delay}')
+            await asyncio.sleep(cancel_delay)
+            print('Setting cancel event')
+            cancel_event.set()
+
+        with pytest.raises(UserCancelledError):
+            await asyncio.gather(
+                test_llm.async_completion(
+                    messages=[{'role': 'user', 'content': 'Hello!'}],
+                    stream=False,
+                ),
+                cancel_after_delay(),
+            )
+
+    # Ensure the mock was called
+    mock_call_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('cancel_after_chunks', [1, 3, 5, 7, 9])
+async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
+    cancel_requested = False
+
+    async def mock_on_cancel_requested():
+        nonlocal cancel_requested
+        return cancel_requested
+
+    config = load_app_config()
+    config.on_cancel_requested_fn = mock_on_cancel_requested
+
+    test_messages = [
+        'This is ',
+        'a test ',
+        'message ',
+        'with ',
+        'multiple ',
+        'chunks ',
+        'to ',
+        'simulate ',
+        'a ',
+        'longer ',
+        'streaming ',
+        'response.',
+    ]
+
+    async def mock_acompletion(*args, **kwargs):
+        for i, content in enumerate(test_messages):
+            yield {'choices': [{'delta': {'content': content}}]}
+            if i + 1 == cancel_after_chunks:
+                nonlocal cancel_requested
+                cancel_requested = True
+            if cancel_requested:
+                raise UserCancelledError('LLM request cancelled by user')
+            await asyncio.sleep(0.05)  # Simulate some delay between chunks
+
+    with patch.object(
+        LLM, '_call_acompletion', new_callable=AsyncMock
+    ) as mock_call_acompletion:
+        mock_call_acompletion.return_value = mock_acompletion()
+        test_llm = LLM(config=config.get_llm_config())
+
+        received_chunks = []
+        with pytest.raises(UserCancelledError):
+            async for chunk in test_llm.async_streaming_completion(
+                messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
+            ):
+                received_chunks.append(chunk['choices'][0]['delta']['content'])
+                print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
+
+        # Assert that we received the expected number of chunks before cancellation
+        assert len(received_chunks) == cancel_after_chunks
+        assert received_chunks == test_messages[:cancel_after_chunks]
+
+    # Ensure the mock was called
+    mock_call_acompletion.assert_called_once()