瀏覽代碼

Dump trajectories with delegate history if configured (#4336)

Boxuan Li 1 年之前
父節點
當前提交
7186224899

+ 3 - 0
config.template.toml

@@ -28,6 +28,9 @@ workspace_base = "./workspace"
 # Enable saving and restoring the session when run from CLI
 #enable_cli_session = false
 
+# Path to store trajectories
+#trajectories_path="./trajectories"
+
 # File store path
 #file_store_path = "/tmp/file_store"
 

+ 2 - 0
openhands/core/config/app_config.py

@@ -28,6 +28,7 @@ class AppConfig:
         runtime: The runtime environment.
         file_store: The file store to use.
         file_store_path: The path to the file store.
+        trajectories_path: The folder path to store trajectories.
         workspace_base: The base path for the workspace. Defaults to ./workspace as an absolute path.
         workspace_mount_path: The path to mount the workspace. This is set to the workspace base by default.
         workspace_mount_path_in_sandbox: The path to mount the workspace in the sandbox. Defaults to /workspace.
@@ -53,6 +54,7 @@ class AppConfig:
     runtime: str = 'eventstream'
     file_store: str = 'memory'
     file_store_path: str = '/tmp/file_store'
+    trajectories_path: str | None = None
     # TODO: clean up workspace path after the removal of ServerRuntime
     workspace_base: str = os.path.join(os.getcwd(), 'workspace')
     workspace_mount_path: str | None = (

+ 14 - 0
openhands/core/main.py

@@ -1,5 +1,7 @@
 import asyncio
 import hashlib
+import json
+import os
 import sys
 import uuid
 from typing import Callable, Protocol, Type
@@ -21,6 +23,7 @@ from openhands.events.action import MessageAction
 from openhands.events.action.action import Action
 from openhands.events.event import Event
 from openhands.events.observation import AgentStateChangedObservation
+from openhands.events.serialization.event import event_to_trajectory
 from openhands.llm.llm import LLM
 from openhands.runtime import get_runtime_cls
 from openhands.runtime.runtime import Runtime
@@ -202,6 +205,17 @@ async def run_controller(
     await controller.close()
     state = controller.get_state()
 
+    # save trajectories if applicable
+    if config.trajectories_path is not None:
+        file_path = os.path.join(config.trajectories_path, sid + '.json')
+        os.makedirs(os.path.dirname(file_path), exist_ok=True)
+        histories = [
+            event_to_trajectory(event)
+            for event in state.history.get_events(include_delegates=True)
+        ]
+        with open(file_path, 'w') as f:
+            json.dump(histories, f)
+
     return state
 
 

+ 2 - 0
openhands/events/serialization/__init__.py

@@ -5,6 +5,7 @@ from openhands.events.serialization.event import (
     event_from_dict,
     event_to_dict,
     event_to_memory,
+    event_to_trajectory,
 )
 from openhands.events.serialization.observation import (
     observation_from_dict,
@@ -15,5 +16,6 @@ __all__ = [
     'event_from_dict',
     'event_to_dict',
     'event_to_memory',
+    'event_to_trajectory',
     'observation_from_dict',
 ]

+ 10 - 2
openhands/events/serialization/event.py

@@ -11,11 +11,10 @@ from openhands.events.serialization.utils import remove_fields
 TOP_KEYS = ['id', 'timestamp', 'source', 'message', 'cause', 'action', 'observation']
 UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause']
 
-DELETE_FROM_MEMORY_EXTRAS = {
+DELETE_FROM_TRAJECTORY_EXTRAS = {
     'screenshot',
     'dom_object',
     'axtree_object',
-    'open_pages_urls',
     'active_page_index',
     'last_browser_action',
     'last_browser_action_error',
@@ -23,6 +22,8 @@ DELETE_FROM_MEMORY_EXTRAS = {
     'extra_element_properties',
 }
 
+DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {'open_pages_urls'}
+
 
 def event_from_dict(data) -> 'Event':
     evt: Event
@@ -73,6 +74,13 @@ def event_to_dict(event: 'Event') -> dict:
     return d
 
 
+def event_to_trajectory(event: 'Event') -> dict:
+    d = event_to_dict(event)
+    if 'extras' in d:
+        remove_fields(d['extras'], DELETE_FROM_TRAJECTORY_EXTRAS)
+    return d
+
+
 def event_to_memory(event: 'Event', max_message_chars: int) -> dict:
     d = event_to_dict(event)
     d.pop('id', None)

+ 5 - 0
tests/unit/test_observation_serialization.py

@@ -6,6 +6,7 @@ from openhands.events.serialization import (
     event_from_dict,
     event_to_dict,
     event_to_memory,
+    event_to_trajectory,
 )
 
 
@@ -20,12 +21,16 @@ def serialization_deserialization(
         observation_instance, cls
     ), 'The observation instance should be an instance of CmdOutputObservation.'
     serialized_observation_dict = event_to_dict(observation_instance)
+    serialized_observation_trajectory = event_to_trajectory(observation_instance)
     serialized_observation_memory = event_to_memory(
         observation_instance, max_message_chars
     )
     assert (
         serialized_observation_dict == original_observation_dict
     ), 'The serialized observation should match the original observation dict.'
+    assert (
+        serialized_observation_trajectory == original_observation_dict
+    ), 'The serialized observation trajectory should match the original observation dict.'
     original_observation_dict.pop('message', None)
     original_observation_dict.pop('id', None)
     original_observation_dict.pop('timestamp', None)