浏览代码

remove screenshot in browser observation (#1588)

* remove screenshot in browser observation

* refactor utils

* allow only dict

* fix screenshot not showing up in frontend

---------

Co-authored-by: Robert Brennan <accounts@rbren.io>
Frank Xu 1 年之前
父节点
当前提交
26dcf4fd7c

+ 1 - 30
agenthub/micro/agent.py

@@ -1,4 +1,3 @@
-import copy
 import json
 from typing import Dict, List
 
@@ -42,39 +41,11 @@ def my_encoder(obj):
         return obj.to_dict()
 
 
-def _remove_fields(obj, fields: set[str]):
-    """
-    Remove fields from an object
-
-    Parameters:
-    - obj (Object): The object to remove fields from
-    - fields (set[str]): A set of field names to remove from the object
-    """
-    if isinstance(obj, dict):
-        for field in fields:
-            if field in obj:
-                del obj[field]
-        for _, value in obj.items():
-            _remove_fields(value, fields)
-    elif isinstance(obj, list) or isinstance(obj, tuple):
-        for item in obj:
-            _remove_fields(item, fields)
-    elif hasattr(obj, '__dataclass_fields__'):
-        for field in fields:
-            if field in obj.__dataclass_fields__:
-                setattr(obj, field, None)
-        for value in obj.__dict__.values():
-            _remove_fields(value, fields)
-
-
 def to_json(obj, **kwargs):
     """
     Serialize an object to str format
     """
-    # Remove things like screenshots that shouldn't be in a prompt
-    sanitized_obj = copy.deepcopy(obj)
-    _remove_fields(sanitized_obj, {'screenshot'})
-    return json.dumps(sanitized_obj, default=my_encoder, **kwargs)
+    return json.dumps(obj, default=my_encoder, **kwargs)
 
 
 class MicroAgent(Agent):

+ 0 - 2
agenthub/monologue_agent/agent.py

@@ -114,8 +114,6 @@ class MonologueAgent(Agent):
         - event (dict): The event that will be added to monologue and memory
         """
 
-        if 'extras' in event and 'screenshot' in event['extras']:
-            del event['extras']['screenshot']
         if (
             'args' in event
             and 'output' in event['args']

+ 0 - 5
agenthub/planner_agent/prompt.py

@@ -172,11 +172,6 @@ def get_prompt(plan: Plan, history: List[Tuple[Action, Observation]]) -> str:
             latest_action = action
         if not isinstance(observation, NullObservation):
             observation_dict = observation.to_memory()
-            if (
-                'extras' in observation_dict
-                and 'screenshot' in observation_dict['extras']
-            ):
-                del observation_dict['extras']['screenshot']
             history_dicts.append(observation_dict)
     history_str = json.dumps(history_dicts, indent=2)
     current_task = plan.get_current_task()

+ 2 - 2
frontend/src/services/observations.ts

@@ -17,8 +17,8 @@ export function handleObservationMessage(message: ObservationMessage) {
       store.dispatch(appendJupyterOutput(message.content));
       break;
     case ObservationType.BROWSE:
-      if (message.extras?.screenshot) {
-        store.dispatch(setScreenshotSrc(message.extras.screenshot));
+      if (message.screenshot) {
+        store.dispatch(setScreenshotSrc(message.screenshot));
       }
       if (message.extras?.url) {
         store.dispatch(setUrl(message.extras.url));

+ 22 - 10
opendevin/events/observation/browse.py

@@ -1,6 +1,7 @@
 from dataclasses import dataclass, field
 
 from opendevin.core.schema import ObservationType
+from opendevin.events.utils import remove_fields
 
 from .observation import Observation
 
@@ -12,28 +13,39 @@ class BrowserOutputObservation(Observation):
     """
 
     url: str
-    screenshot: str
+    screenshot: str = field(repr=False)  # don't show in repr
     status_code: int = 200
     error: bool = False
     observation: str = ObservationType.BROWSE
     # do not include in the memory
     open_pages_urls: list = field(default_factory=list)
     active_page_index: int = -1
-    dom_object: dict = field(default_factory=dict)
-    axtree_object: dict = field(default_factory=dict)
+    dom_object: dict = field(default_factory=dict, repr=False)  # don't show in repr
+    axtree_object: dict = field(default_factory=dict, repr=False)  # don't show in repr
     last_browser_action: str = ''
     focused_element_bid: str = ''
 
+    def to_dict(self):
+        dictionary = super().to_dict()
+        # add screenshot for frontend showcase only, not for agent consumption
+        dictionary['screenshot'] = self.screenshot
+        return dictionary
+
     def to_memory(self) -> dict:
         memory_dict = super().to_memory()
         # remove some fields from the memory, as currently they are too big for LLMs
-        # TODO: find a more elegant way to handle this
-        memory_dict['extras'].pop('dom_object', None)
-        memory_dict['extras'].pop('axtree_object', None)
-        memory_dict['extras'].pop('open_pages_urls', None)
-        memory_dict['extras'].pop('active_page_index', None)
-        memory_dict['extras'].pop('last_browser_action', None)
-        memory_dict['extras'].pop('focused_element_bid', None)
+        remove_fields(
+            memory_dict['extras'],
+            {
+                'screenshot',
+                'dom_object',
+                'axtree_object',
+                'open_pages_urls',
+                'active_page_index',
+                'last_browser_action',
+                'focused_element_bid',
+            },
+        )
         return memory_dict
 
     @property

+ 21 - 0
opendevin/events/utils.py

@@ -0,0 +1,21 @@
+def remove_fields(obj, fields: set[str]):
+    """
+    Remove fields from an object.
+
+    Parameters:
+    - obj: The dictionary, or list of dictionaries to remove fields from
+    - fields (set[str]): A set of field names to remove from the object
+    """
+    if isinstance(obj, dict):
+        for field in fields:
+            if field in obj:
+                del obj[field]
+        for _, value in obj.items():
+            remove_fields(value, fields)
+    elif isinstance(obj, list) or isinstance(obj, tuple):
+        for item in obj:
+            remove_fields(item, fields)
+    elif hasattr(obj, '__dataclass_fields__'):
+        raise ValueError(
+            'Object must not contain dataclass, consider converting to dict first'
+        )