Quellcode durchsuchen

Serialization of Actions and Observations (#314)

* checkout geohotstan work

* merge session.py changes

* add observation ids

* ignore null actions and obs

* add back action messages

* fix lint
Robert Brennan vor 2 Jahren
Ursprung
Commit
32a3a0259a

+ 21 - 0
opendevin/action/__init__.py

@@ -4,6 +4,27 @@ from .browse import BrowseURLAction
 from .fileop import FileReadAction, FileWriteAction
 from .agent import AgentRecallAction, AgentThinkAction, AgentFinishAction, AgentEchoAction, AgentSummarizeAction
 
+actions = (  # classes registered in ACTION_TYPE_TO_CLASS below, keyed by each class's `action` string
+    CmdKillAction,
+    CmdRunAction,
+    BrowseURLAction,
+    FileReadAction,
+    FileWriteAction,
+    AgentRecallAction,
+    AgentThinkAction,
+    AgentFinishAction  # NOTE(review): AgentEchoAction and AgentSummarizeAction are imported above but not listed here — confirm this is intentional
+)
+
+ACTION_TYPE_TO_CLASS = {action_class.action:action_class for action_class in actions} # type: ignore[attr-defined]
+
+def action_class_initialize_dispatcher(action: str, *args: str, **kwargs: str) -> Action:
+    action_class = ACTION_TYPE_TO_CLASS.get(action)
+    if action_class is None:
+        raise KeyError(f"'{action=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}")
+    return action_class(*args, **kwargs)
+
+CLASS_TO_ACTION_TYPE = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()}  # inverse of ACTION_TYPE_TO_CLASS: Action class -> action type string
+
 __all__ = [
     "Action",
     "NullAction",

+ 6 - 4
opendevin/action/agent.py

@@ -10,6 +10,7 @@ if TYPE_CHECKING:
 @dataclass
 class AgentRecallAction(ExecutableAction):
     query: str
+    action: str = "recall"
 
     def run(self, controller: "AgentController") -> AgentRecallObservation:
         return AgentRecallObservation(
@@ -21,12 +22,11 @@ class AgentRecallAction(ExecutableAction):
     def message(self) -> str:
         return f"Let me dive into my memories to find what you're looking for! Searching for: '{self.query}'. This might take a moment."
 
-
-
 @dataclass
 class AgentThinkAction(NotExecutableAction):
     thought: str
     runnable: bool = False
+    action: str = "think"
 
     def run(self, controller: "AgentController") -> "Observation":
         raise NotImplementedError
@@ -35,11 +35,11 @@ class AgentThinkAction(NotExecutableAction):
     def message(self) -> str:
         return self.thought
 
-
 @dataclass
 class AgentEchoAction(ExecutableAction):
     content: str
     runnable: bool = True
+    action: str = "echo"
 
     def run(self, controller: "AgentController") -> "Observation":
         return AgentMessageObservation(self.content)
@@ -52,6 +52,8 @@ class AgentEchoAction(ExecutableAction):
 class AgentSummarizeAction(NotExecutableAction):
     summary: str
 
+    action: str = "summarize"
+
     @property
     def message(self) -> str:
         return self.summary
@@ -59,6 +61,7 @@ class AgentSummarizeAction(NotExecutableAction):
 @dataclass
 class AgentFinishAction(NotExecutableAction):
     runnable: bool = False
+    action: str = "finish"
 
     def run(self, controller: "AgentController") -> "Observation":
         raise NotImplementedError
@@ -66,4 +69,3 @@ class AgentFinishAction(NotExecutableAction):
     @property
     def message(self) -> str:
         return "All done! What's next on the agenda?"
-

+ 11 - 3
opendevin/action/base.py

@@ -1,5 +1,4 @@
-from dataclasses import dataclass
-
+from dataclasses import dataclass, asdict
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -12,7 +11,12 @@ class Action:
         raise NotImplementedError
 
     def to_dict(self):
-        return {"action": self.__class__.__name__, "args": self.__dict__, "message": self.message}
+        d = asdict(self)
+        try:
+            v = d.pop('action')
+        except KeyError:
+            raise NotImplementedError(f'{self=} does not have action attribute set')
+        return {'action': v, "args": d, "message": self.message}
 
     @property
     def executable(self) -> bool:
@@ -24,21 +28,25 @@ class Action:
 
 
 
+@dataclass
 class ExecutableAction(Action):
     @property
     def executable(self) -> bool:
         return True
 
 
+@dataclass
 class NotExecutableAction(Action):
     @property
     def executable(self) -> bool:
         return False
 
+@dataclass
 class NullAction(NotExecutableAction):
     """An action that does nothing.
     This is used when the agent needs to receive user follow-up messages from the frontend.
     """
+    action: str = "null"
 
     @property
     def message(self) -> str:

+ 3 - 1
opendevin/action/bash.py

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
 class CmdRunAction(ExecutableAction):
     command: str
     background: bool = False
+    action: str = "run"
 
     def run(self, controller: "AgentController") -> "CmdOutputObservation":
         return controller.command_manager.run_command(self.command, self.background)
@@ -23,10 +24,11 @@ class CmdRunAction(ExecutableAction):
 @dataclass
 class CmdKillAction(ExecutableAction):
     id: int
+    action: str = "kill"
 
     def run(self, controller: "AgentController") -> "CmdOutputObservation":
         return controller.command_manager.kill_command(self.id)
 
     @property
     def message(self) -> str:
-        return f"Killing command: {self.id}"
+        return f"Killing command: {self.id}"

+ 2 - 1
opendevin/action/browse.py

@@ -8,6 +8,7 @@ from .base import ExecutableAction
 @dataclass
 class BrowseURLAction(ExecutableAction):
     url: str
+    action: str = "browse"
 
     def run(self, *args, **kwargs) -> BrowserOutputObservation:
         try:
@@ -26,4 +27,4 @@ class BrowseURLAction(ExecutableAction):
 
     @property
     def message(self) -> str:
-        return f"Browsing URL: {self.url}"
+        return f"Browsing URL: {self.url}"

+ 2 - 2
opendevin/action/fileop.py

@@ -17,6 +17,7 @@ def resolve_path(base_path, file_path):
 @dataclass
 class FileReadAction(ExecutableAction):
     path: str
+    action: str = "read"
 
     def run(self, controller) -> FileReadObservation:
         path = resolve_path(controller.workdir, self.path)
@@ -29,11 +30,11 @@ class FileReadAction(ExecutableAction):
     def message(self) -> str:
         return f"Reading file: {self.path}"
 
-
 @dataclass
 class FileWriteAction(ExecutableAction):
     path: str
     contents: str
+    action: str = "write"
 
     def run(self, controller) -> FileWriteObservation:
         path = resolve_path(controller.workdir, self.path)
@@ -44,4 +45,3 @@ class FileWriteAction(ExecutableAction):
     @property
     def message(self) -> str:
         return f"Writing file: {self.path}"
-

+ 13 - 1
opendevin/observation.py

@@ -18,8 +18,11 @@ class Observation:
         """Converts the observation to a dictionary."""
         extras = copy.deepcopy(self.__dict__)
         extras.pop("content", None)
+        observation = "observation"
+        if hasattr(self, "observation"):
+            observation = self.observation
         return {
-            "observation": self.__class__.__name__,
+            "observation": observation,
             "content": self.content,
             "extras": extras,
             "message": self.message,
@@ -40,6 +43,7 @@ class CmdOutputObservation(Observation):
     command_id: int
     command: str
     exit_code: int = 0
+    observation : str = "run"
 
     @property
     def error(self) -> bool:
@@ -56,6 +60,7 @@ class FileReadObservation(Observation):
     """
 
     path: str
+    observation : str = "read"
 
     @property
     def message(self) -> str:
@@ -68,6 +73,7 @@ class FileWriteObservation(Observation):
     """
 
     path: str
+    observation : str = "write"
 
     @property
     def message(self) -> str:
@@ -82,6 +88,7 @@ class BrowserOutputObservation(Observation):
     url: str
     status_code: int = 200
     error: bool = False
+    observation : str = "browse"
 
     @property
     def message(self) -> str:
@@ -95,6 +102,7 @@ class UserMessageObservation(Observation):
     """
 
     role: str = "user"
+    observation : str = "message"
 
     @property
     def message(self) -> str:
@@ -108,6 +116,7 @@ class AgentMessageObservation(Observation):
     """
 
     role: str = "assistant"
+    observation : str = "message"
 
     @property
     def message(self) -> str:
@@ -122,6 +131,7 @@ class AgentRecallObservation(Observation):
 
     memories: List[str]
     role: str = "assistant"
+    observation : str = "recall"
 
     @property
     def message(self) -> str:
@@ -133,6 +143,7 @@ class AgentErrorObservation(Observation):
     """
     This data class represents an error encountered by the agent.
     """
+    observation : str = "error"
 
     @property
     def message(self) -> str:
@@ -144,6 +155,7 @@ class NullObservation(Observation):
     This data class represents a null observation.
     This is used when the produced action is NOT executable.
     """
+    observation : str = "null"
 
     @property
     def message(self) -> str:

+ 14 - 74
opendevin/server/session.py

@@ -1,58 +1,22 @@
 import asyncio
 import os
-from typing import Dict, Optional, Type
+from typing import Optional
 
 from fastapi import WebSocketDisconnect
 
 from opendevin.action import (
     Action,
-    AgentFinishAction,
-    AgentRecallAction,
-    AgentThinkAction,
-    BrowseURLAction,
-    CmdKillAction,
-    CmdRunAction,
-    FileReadAction,
-    FileWriteAction,
     NullAction,
 )
+from opendevin.observation import NullObservation
 from opendevin.agent import Agent
 from opendevin.controller import AgentController
 from opendevin.llm.llm import LLM
 from opendevin.observation import Observation, UserMessageObservation
 
-# NOTE: this is a temporary solution - but hopefully we can use Action/Observation throughout the codebase
-ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
-    "run": CmdRunAction,
-    "kill": CmdKillAction,
-    "browse": BrowseURLAction,
-    "read": FileReadAction,
-    "write": FileWriteAction,
-    "recall": AgentRecallAction,
-    "think": AgentThinkAction,
-    "finish": AgentFinishAction,
-}
-
-
 DEFAULT_WORKSPACE_DIR = os.getenv("WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace"))
 LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4-0125-preview")
 
-def parse_event(data):
-    if "action" not in data:
-        return None
-    action = data["action"]
-    args = {}
-    if "args" in data:
-        args = data["args"]
-    message = None
-    if "message" in data:
-        message = data["message"]
-    return {
-        "action": action,
-        "args": args,
-        "message": message,
-    }
-
 class Session:
     def __init__(self, websocket):
         self.websocket = websocket
@@ -84,20 +48,20 @@ class Session:
                     await self.send_error("Invalid JSON")
                     continue
 
-                event = parse_event(data)
-                if event is None:
+                action = data.get("action", None)
+                if action is None:
                     await self.send_error("Invalid event")
                     continue
-                if event["action"] == "initialize":
-                    await self.create_controller(event)
-                elif event["action"] == "start":
-                    await self.start_task(event)
+                if action == "initialize":
+                    await self.create_controller(data)
+                elif action == "start":
+                    await self.start_task(data)
                 else:
                     if self.controller is None:
                         await self.send_error("No agent started. Please wait a second...")
 
-                    elif event["action"] == "chat":
-                        self.controller.add_history(NullAction(), UserMessageObservation(event["message"]))
+                    elif action == "chat":
+                        self.controller.add_history(NullAction(), UserMessageObservation(data["message"]))
                     else:
                         # TODO: we only need to implement user message for now
                         # since even Devin does not support having the user taking other
@@ -147,33 +111,9 @@ class Session:
         self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")
 
     def on_agent_event(self, event: Observation | Action):
-        # FIXME: we need better serialization
+        if isinstance(event, NullAction):
+            return
+        if isinstance(event, NullObservation):
+            return
         event_dict = event.to_dict()
-        if "action" in event_dict:
-            if event_dict["action"] == "CmdRunAction":
-                event_dict["action"] = "run"
-            elif event_dict["action"] == "CmdKillAction":
-                event_dict["action"] = "kill"
-            elif event_dict["action"] == "BrowseURLAction":
-                event_dict["action"] = "browse"
-            elif event_dict["action"] == "FileReadAction":
-                event_dict["action"] = "read"
-            elif event_dict["action"] == "FileWriteAction":
-                event_dict["action"] = "write"
-            elif event_dict["action"] == "AgentFinishAction":
-                event_dict["action"] = "finish"
-            elif event_dict["action"] == "AgentRecallAction":
-                event_dict["action"] = "recall"
-            elif event_dict["action"] == "AgentThinkAction":
-                event_dict["action"] = "think"
-        if "observation" in event_dict:
-            if event_dict["observation"] == "UserMessageObservation":
-                event_dict["observation"] = "chat"
-            elif event_dict["observation"] == "AgentMessageObservation":
-                event_dict["observation"] = "chat"
-            elif event_dict["observation"] == "CmdOutputObservation":
-                event_dict["observation"] = "run"
-            elif event_dict["observation"] == "FileReadObservation":
-                event_dict["observation"] = "read"
-
         asyncio.create_task(self.send(event_dict), name="send event in callback")