Просмотр исходного кода

Refactor messages serialization (#3832)

Co-authored-by: Robert Brennan <accounts@rbren.io>
Engel Nyst 1 год назад
Родитель
Commit
8fdfece059

+ 1 - 3
agenthub/browsing_agent/browsing_agent.py

@@ -216,10 +216,8 @@ class BrowsingAgent(Agent):
         prompt = get_prompt(error_prefix, cur_url, cur_axtree_txt, prev_action_str)
         messages.append(Message(role='user', content=[TextContent(text=prompt)]))
 
-        flat_messages = self.llm.format_messages_for_llm(messages)
-
         response = self.llm.completion(
-            messages=flat_messages,
+            messages=self.llm.format_messages_for_llm(messages),
             temperature=0.0,
             stop=[')```', ')\n```'],
         )

+ 2 - 2
config.template.toml

@@ -164,12 +164,12 @@ model = "gpt-4o"
 # If model is vision capable, this option allows to disable image processing (useful for cost reduction).
 #disable_vision = true
 
-[llm.gpt3]
+[llm.gpt4o-mini]
 # API key to use
 api_key = "your-api-key"
 
 # Model to use
-model = "gpt-3.5"
+model = "gpt-4o-mini"
 
 #################################### Agent ###################################
 # Configuration for agents (group name starts with 'agent')

+ 2 - 2
evaluation/regression/README.md

@@ -14,9 +14,9 @@ To run the tests for OpenHands project, you can use the provided test runner scr
 3. Navigate to the root directory of the project.
 4. Run the test suite using the test runner script with the required arguments:
    ```
-   python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-3.5-turbo
+   python evaluation/regression/run_tests.py --OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxxxxxxxx --model=gpt-4o
    ```
-   Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-3.5-turbo`, but you can specify a different model if needed.
+   Replace `sk-xxxxxxxxxxxxxxxxxxxxxx` with your actual OpenAI API key. The default model is `gpt-4o`, but you can specify a different model if needed.
 
 The test runner will discover and execute all the test cases in the `cases/` directory, and display the results of the test suite, including the status of each individual test case and the overall summary.
 

+ 19 - 58
openhands/core/message.py

@@ -1,10 +1,7 @@
 from enum import Enum
-from typing import Union
+from typing import Literal
 
 from pydantic import BaseModel, Field, model_serializer
-from typing_extensions import Literal
-
-from openhands.core.logger import openhands_logger as logger
 
 
 class ContentType(Enum):
@@ -60,60 +57,24 @@ class Message(BaseModel):
 
     @model_serializer
     def serialize_model(self) -> dict:
-        content: list[dict[str, str | dict[str, str]]] = []
-
-        for item in self.content:
-            if isinstance(item, TextContent):
-                content.append(item.model_dump())
-            elif isinstance(item, ImageContent):
-                content.extend(item.model_dump())
-
-        return {'content': content, 'role': self.role}
-
-
-def format_messages(
-    messages: Union[Message, list[Message]],
-    with_images: bool,
-    with_prompt_caching: bool,
-) -> list[dict]:
-    if not isinstance(messages, list):
-        messages = [messages]
-
-    if with_images or with_prompt_caching:
-        return [message.model_dump() for message in messages]
-
-    converted_messages = []
-    for message in messages:
-        content_parts = []
-        role = 'user'
-
-        if isinstance(message, str) and message:
-            content_parts.append(message)
-        elif isinstance(message, dict):
-            role = message.get('role', 'user')
-            if 'content' in message and message['content']:
-                content_parts.append(message['content'])
-        elif isinstance(message, Message):
-            role = message.role
-            for content in message.content:
-                if isinstance(content, list):
-                    for item in content:
-                        if isinstance(item, TextContent) and item.text:
-                            content_parts.append(item.text)
-                elif isinstance(content, TextContent) and content.text:
-                    content_parts.append(content.text)
-        else:
-            logger.error(
-                f'>>> `message` is not a string, dict, or Message: {type(message)}'
+        content: list[dict] | str
+        if self.role == 'system':
+            # For system role, concatenate all text content into a single string
+            content = '\n'.join(
+                item.text for item in self.content if isinstance(item, TextContent)
             )
-
-        if content_parts:
-            content_str = '\n'.join(content_parts)
-            converted_messages.append(
-                {
-                    'role': role,
-                    'content': content_str,
-                }
+        elif self.role == 'assistant' and not self.contains_image:
+            # For assistant role without vision, concatenate all text content into a single string
+            content = '\n'.join(
+                item.text for item in self.content if isinstance(item, TextContent)
             )
+        else:
+            # For user role or assistant role with vision enabled, serialize each content item
+            content = []
+            for item in self.content:
+                if isinstance(item, TextContent):
+                    content.append(item.model_dump())
+                elif isinstance(item, ImageContent):
+                    content.extend(item.model_dump())
 
-    return converted_messages
+        return {'content': content, 'role': self.role}

+ 5 - 8
openhands/llm/llm.py

@@ -2,7 +2,6 @@ import asyncio
 import copy
 import warnings
 from functools import partial
-from typing import Union
 
 from openhands.core.config import LLMConfig
 from openhands.runtime.utils.shutdown_listener import should_continue
@@ -32,7 +31,7 @@ from tenacity import (
 from openhands.core.exceptions import LLMResponseError, UserCancelledError
 from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
-from openhands.core.message import Message, format_messages
+from openhands.core.message import Message
 from openhands.core.metrics import Metrics
 
 __all__ = ['LLM']
@@ -633,9 +632,7 @@ class LLM:
     def reset(self):
         self.metrics = Metrics()
 
-    def format_messages_for_llm(
-        self, messages: Union[Message, list[Message]]
-    ) -> list[dict]:
-        return format_messages(
-            messages, self.vision_is_active(), self.is_caching_prompt_active()
-        )
+    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
+        if isinstance(messages, Message):
+            return [messages.model_dump()]
+        return [message.model_dump() for message in messages]

+ 24 - 5
tests/integration/conftest.py

@@ -11,7 +11,6 @@ from http.server import HTTPServer, SimpleHTTPRequestHandler
 import pytest
 from litellm import completion
 
-from openhands.core.message import format_messages
 from openhands.llm.llm import message_separator
 
 script_dir = os.environ.get('SCRIPT_DIR')
@@ -78,6 +77,29 @@ def get_log_id(prompt_log_name):
         return match.group(1)
 
 
+def _format_messages(messages):
+    message_str = ''
+    for message in messages:
+        if isinstance(message, str):
+            message_str += message_separator + message if message_str else message
+        elif isinstance(message, dict):
+            if isinstance(message['content'], list):
+                for m in message['content']:
+                    if isinstance(m, str):
+                        message_str += message_separator + m if message_str else m
+                    elif isinstance(m, dict) and m['type'] == 'text':
+                        message_str += (
+                            message_separator + m['text'] if message_str else m['text']
+                        )
+            elif isinstance(message['content'], str):
+                message_str += (
+                    message_separator + message['content']
+                    if message_str
+                    else message['content']
+                )
+    return message_str
+
+
 def apply_prompt_and_get_mock_response(
     test_name: str, messages: str, id: int
 ) -> str | None:
@@ -185,10 +207,7 @@ def mock_user_response(*args, test_name, **kwargs):
 def mock_completion(*args, test_name, **kwargs):
     global cur_id
     messages = kwargs['messages']
-    plain_messages = format_messages(
-        messages, with_images=False, with_prompt_caching=False
-    )
-    message_str = message_separator.join(msg['content'] for msg in plain_messages)
+    message_str = _format_messages(messages)  # text only
 
     # this assumes all response_(*).log filenames are in numerical order, starting from one
     cur_id += 1