Browse Source

Refactor messages serialization (#3832)

Co-authored-by: Robert Brennan <accounts@rbren.io>
Engel Nyst 1 year ago
parent
commit
8fdfece059

+ 1 - 3
agenthub/browsing_agent/browsing_agent.py

@@ -216,10 +216,8 @@ class BrowsingAgent(Agent):
         prompt = get_prompt(error_prefix, cur_url, cur_axtree_txt, prev_action_str)
         messages.append(Message(role='user', content=[TextContent(text=prompt)]))
 
-        flat_messages = self.llm.format_messages_for_llm(messages)
-
         response = self.llm.completion(
-            messages=flat_messages,
+            messages=self.llm.format_messages_for_llm(messages),
             temperature=0.0,
             stop=[')```', ')\n```'],
         )

+ 2 - 2
config.template.toml

@@ -164,12 +164,12 @@ model = "gpt-4o"
 # If model is vision capable, this option allows to disable image processing (useful for cost reduction).
 #disable_vision = true
 
-[llm.gpt3]
+[llm.gpt4o-mini]
 # API key to use
 api_key = "your-api-key"
 
 # Model to use
-model = "gpt-3.5"
+model = "gpt-4o-mini"
 
 #################################### Agent ###################################
 # Configuration for agents (group name starts with 'agent')

+ 2 - 2
evaluation/regression/README.md

@@ -14,9 +14,9 @@ To run the tests for OpenHands project, you can use the provided test runner scr
 3. Navigate to the root directory of the project.
 4. Run the test suite using the test runner script with the required arguments:
    ```
-   python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-3.5-turbo
+   python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-4o
    ```
-   Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-3.5-turbo`, but you can specify a different model if needed.
+   Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-4o`, but you can specify a different model if needed.
 
 The test runner will discover and execute all the test cases in the `cases/` directory, and display the results of the test suite, including the status of each individual test case and the overall summary.
 

+ 19 - 58
openhands/core/message.py

@@ -1,10 +1,7 @@
 from enum import Enum
-from typing import Union
+from typing import Literal
 
 from pydantic import BaseModel, Field, model_serializer
-from typing_extensions import Literal
-
-from openhands.core.logger import openhands_logger as logger
 
 
 class ContentType(Enum):
@@ -60,60 +57,24 @@ class Message(BaseModel):
 
 
     @model_serializer
     def serialize_model(self) -> dict:
-        content: list[dict[str, str | dict[str, str]]] = []
-
-        for item in self.content:
-            if isinstance(item, TextContent):
-                content.append(item.model_dump())
-            elif isinstance(item, ImageContent):
-                content.extend(item.model_dump())
-
-        return {'content': content, 'role': self.role}
-
-
-def format_messages(
-    messages: Union[Message, list[Message]],
-    with_images: bool,
-    with_prompt_caching: bool,
-) -> list[dict]:
-    if not isinstance(messages, list):
-        messages = [messages]
-
-    if with_images or with_prompt_caching:
-        return [message.model_dump() for message in messages]
-
-    converted_messages = []
-    for message in messages:
-        content_parts = []
-        role = 'user'
-
-        if isinstance(message, str) and message:
-            content_parts.append(message)
-        elif isinstance(message, dict):
-            role = message.get('role', 'user')
-            if 'content' in message and message['content']:
-                content_parts.append(message['content'])
-        elif isinstance(message, Message):
-            role = message.role
-            for content in message.content:
-                if isinstance(content, list):
-                    for item in content:
-                        if isinstance(item, TextContent) and item.text:
-                            content_parts.append(item.text)
-                elif isinstance(content, TextContent) and content.text:
-                    content_parts.append(content.text)
-        else:
-            logger.error(
-                f'>>> `message` is not a string, dict, or Message: {type(message)}'
+        content: list[dict] | str
+        if self.role == 'system':
+            # For system role, concatenate all text content into a single string
+            content = '\n'.join(
+                item.text for item in self.content if isinstance(item, TextContent)
             )
-
-        if content_parts:
-            content_str = '\n'.join(content_parts)
-            converted_messages.append(
-                {
-                    'role': role,
-                    'content': content_str,
-                }
+        elif self.role == 'assistant' and not self.contains_image:
+            # For assistant role without vision, concatenate all text content into a single string
+            content = '\n'.join(
+                item.text for item in self.content if isinstance(item, TextContent)
             )
+        else:
+            # For user role or assistant role with vision enabled, serialize each content item
+            content = []
+            for item in self.content:
+                if isinstance(item, TextContent):
+                    content.append(item.model_dump())
+                elif isinstance(item, ImageContent):
+                    content.extend(item.model_dump())
 
-    return converted_messages
+        return {'content': content, 'role': self.role}

+ 5 - 8
openhands/llm/llm.py

@@ -2,7 +2,6 @@ import asyncio
 import copy
 import warnings
 from functools import partial
-from typing import Union
 
 from openhands.core.config import LLMConfig
 from openhands.runtime.utils.shutdown_listener import should_continue
@@ -32,7 +31,7 @@ from tenacity import (
 from openhands.core.exceptions import LLMResponseError, UserCancelledError
 from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
-from openhands.core.message import Message, format_messages
+from openhands.core.message import Message
 from openhands.core.metrics import Metrics
 
 __all__ = ['LLM']
@@ -633,9 +632,7 @@ class LLM:
     def reset(self):
         self.metrics = Metrics()
 
-    def format_messages_for_llm(
-        self, messages: Union[Message, list[Message]]
-    ) -> list[dict]:
-        return format_messages(
-            messages, self.vision_is_active(), self.is_caching_prompt_active()
-        )
+    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
+        if isinstance(messages, Message):
+            return [messages.model_dump()]
+        return [message.model_dump() for message in messages]

+ 24 - 5
tests/integration/conftest.py

@@ -11,7 +11,6 @@ from http.server import HTTPServer, SimpleHTTPRequestHandler
 import pytest
 from litellm import completion
 
-from openhands.core.message import format_messages
 from openhands.llm.llm import message_separator
 
 script_dir = os.environ.get('SCRIPT_DIR')
@@ -78,6 +77,29 @@ def get_log_id(prompt_log_name):
         return match.group(1)
 
 
+def _format_messages(messages):
+    message_str = ''
+    for message in messages:
+        if isinstance(message, str):
+            message_str += message_separator + message if message_str else message
+        elif isinstance(message, dict):
+            if isinstance(message['content'], list):
+                for m in message['content']:
+                    if isinstance(m, str):
+                        message_str += message_separator + m if message_str else m
+                    elif isinstance(m, dict) and m['type'] == 'text':
+                        message_str += (
+                            message_separator + m['text'] if message_str else m['text']
+                        )
+            elif isinstance(message['content'], str):
+                message_str += (
+                    message_separator + message['content']
+                    if message_str
+                    else message['content']
+                )
+    return message_str
+
+
 def apply_prompt_and_get_mock_response(
     test_name: str, messages: str, id: int
 ) -> str | None:
@@ -185,10 +207,7 @@ def mock_user_response(*args, test_name, **kwargs):
 def mock_completion(*args, test_name, **kwargs):
     global cur_id
     messages = kwargs['messages']
-    plain_messages = format_messages(
-        messages, with_images=False, with_prompt_caching=False
-    )
-    message_str = message_separator.join(msg['content'] for msg in plain_messages)
+    message_str = _format_messages(messages)  # text only
 
     # this assumes all response_(*).log filenames are in numerical order, starting from one
     cur_id += 1