Sfoglia il codice sorgente

New Agent, Action, Observation Abstraction with updated Controller (#105)

* rearrange workspace_dir and max_step as arguments to controller

* remove unused output

* abstract each action into dataclass

* move actions

* fix action import

* move cmd manager and change method to private

* move controller

* rename action folder

* add state

* a draft of Controller & new agent abstraction

* add agent actions

* remove controller file

* add observation to perform a refractor on langchains agent

* revert to make this compatible via translation

* fix typo and translate error

* add error to observation

* index thought as dict

* refractor controller

* fix circular dependency caused by type hint

* add runnable attribute to agent

* add mixin to denote executable

* change baseclass

* make file read/write action compatible w/ docker directory

* remove event

* fix some merge issue

* fix sandbox w/ permission issue

* cleanup history abstraction since langchains agent is not really using it

* tweak to make langchains agent working

* make all actions return observation

* fix missing import

* add echo action for agent

* add error code to cmd output obs

* make cmd manager returns cmd output obs

* fix codeact agent to make it work

* fix all ruff issue

* fix mypy

* add import agenthub back

* add message for Action attribute (migrate from previous event)

* fix typo

* fix instruction setting

* fix instruction setting

* attempt to fix session

* ruff fix

* add .to_dict method for base and observation

* add message for recall

* try to simplify the state_updated_info with tuple of action and obs

* update_info to Tuple[Action, Observation]

* make codeact agent and langchains compatible with Tuple[Action, Observation]

* fix ruff

* fix ruff

* change to base path to fix minimal langchains agent

* add NullAction to potentially handle for chat scenario

* Update opendevin/controller/command_manager.py

Co-authored-by: Robert Brennan <accounts@rbren.io>

* fix event args

* set the default workspace to "workspace"

* make directory relative (so it does not show up to agent in File*Action)

* fix typo

* await to yield for sending observation

* fix message format

---------

Co-authored-by: Robert Brennan <accounts@rbren.io>
Xingyao Wang 1 anno fa
parent
commit
82f934d4cd

+ 76 - 70
agenthub/codeact_agent/__init__.py

@@ -2,12 +2,21 @@ import os
 import re
 from litellm import completion
 from termcolor import colored
-from typing import List, Dict
+from typing import List, Mapping
+
+from opendevin.agent import Agent
+from opendevin.state import State
+from opendevin.action import (
+    Action,
+    CmdRunAction,
+    AgentEchoAction,
+    AgentFinishAction,
+)
+from opendevin.observation import (
+    CmdOutputObservation,
+    AgentMessageObservation,
+)
 
-from opendevin.agent import Agent, Message, Role
-from opendevin.lib.event import Event
-from opendevin.lib.command_manager import CommandManager
-from opendevin.sandbox.sandbox import DockerInteractive
 
 assert (
     "OPENAI_API_KEY" in os.environ
@@ -53,9 +62,7 @@ def parse_response(response) -> str:
 class CodeActAgent(Agent):
     def __init__(
         self,
-        instruction: str,
-        workspace_dir: str,
-        max_steps: int = 100
+        model_name: str
     ) -> None:
         """
         Initializes a new instance of the CodeActAgent class.
@@ -64,69 +71,68 @@ class CodeActAgent(Agent):
         - instruction (str): The instruction for the agent to execute.
         - max_steps (int): The maximum number of steps to run the agent.
         """
-        super().__init__(instruction, workspace_dir, max_steps)
-        self._history = [Message(Role.SYSTEM, SYSTEM_MESSAGE)]
-        self._history.append(Message(Role.USER, instruction))
-        self.env = DockerInteractive(workspace_dir=workspace_dir)
-        print(colored("===USER:===\n" + instruction, "green"))
-
-    def _history_to_messages(self) -> List[Dict]:
-        return [message.to_dict() for message in self._history]
-
-    def run(self) -> None:
-        """
-        Starts the execution of the assigned instruction. This method should
-        be implemented by subclasses to define the specific execution logic.
-        """
-        for _ in range(self.max_steps):
-            response = completion(
-                messages=self._history_to_messages(),
-                model=self.model_name,
-                stop=["</execute>"],
-                temperature=0.0,
-                seed=42,
-            )
-            action = parse_response(response)
-            self._history.append(Message(Role.ASSISTANT, action))
-            print(colored("===ASSISTANT:===\n" + action, "yellow"))
-
-            command = re.search(r"<execute>(.*)</execute>", action, re.DOTALL)
-            if command is not None:
-                # a command was found
-                command_group = command.group(1)
-                if command_group.strip() == "exit":
-                    print(colored("Exit received. Exiting...", "red"))
-                    break
-                # execute the code
-                # TODO: does exit_code get loaded into Message?
-                exit_code, observation = self.env.execute(command_group)
-                self._history.append(Message(Role.ASSISTANT, observation))
-                print(colored("===ENV OBSERVATION:===\n" + observation, "blue"))
-            else:
-                # we could provide a error message for the model to continue similar to
-                # https://github.com/xingyaoww/mint-bench/blob/main/mint/envs/general_env.py#L18-L23
-                observation = INVALID_INPUT_MESSAGE
-                self._history.append(Message(Role.ASSISTANT, observation))
-                print(colored("===ENV OBSERVATION:===\n" + observation, "blue"))
-
-        self.env.close()
-
-    def chat(self, message: str) -> None:
-        """
-        Optional method for interactive communication with the agent during its execution. Implementations
-        can use this method to modify the agent's behavior or state based on chat inputs.
+        super().__init__(model_name)
+        self.messages: List[Mapping[str, str]] = []
+        self.instruction: str = ""
+
+    def step(self, state: State) -> Action:
+        if len(self.messages) == 0:
+            assert self.instruction, "Expecting instruction to be set"
+            self.messages = [
+                {"role": "system", "content": SYSTEM_MESSAGE},
+                {"role": "user", "content": self.instruction},
+            ]
+            print(colored("===USER:===\n" + self.instruction, "green"))
+
+        updated_info = state.updated_info
+
+        if updated_info:
+            for prev_action, obs in updated_info:
+                assert isinstance(prev_action, (CmdRunAction, AgentEchoAction)), "Expecting CmdRunAction or AgentEchoAction for Action"
+
+                if isinstance(obs, AgentMessageObservation):  # warning message from itself
+                    self.messages.append({"role": "user", "content": obs.content})
+                    print(colored("===USER:===\n" + obs.content, "green"))
+                elif isinstance(obs, CmdOutputObservation):
+                    content = "OBSERVATION:\n" + obs.content
+                    content += f"\n[Command {obs.command_id} finished with exit code {obs.exit_code}]]"
+                    self.messages.append({"role": "user", "content": content})
+                    print(colored("===ENV OBSERVATION:===\n" + content, "blue"))
+                else:
+                    raise NotImplementedError(f"Unknown observation type: {obs.__class__}")
+
+        response = completion(
+            messages=self.messages,
+            model=self.model_name,
+            stop=["</execute>"],
+            temperature=0.0,
+            seed=42,
+        )
+        action_str: str = parse_response(response)
+        self.messages.append({"role": "assistant", "content": action_str})
+        print(colored("===ASSISTANT:===\n" + action_str, "yellow"))
+
+        command = re.search(r"<execute>(.*)</execute>", action_str, re.DOTALL)
+        if command is not None:
+            # a command was found
+            command_group = command.group(1)
+            if command_group.strip() == "exit":
+                print(colored("Exit received. Exiting...", "red"))
+                return AgentFinishAction()
+            return CmdRunAction(command = command_group)
+            # # execute the code
+            # # TODO: does exit_code get loaded into Message?
+            # exit_code, observation = self.env.execute(command_group)
+            # self._history.append(Message(Role.ASSISTANT, observation))
+            # print(colored("===ENV OBSERVATION:===\n" + observation, "blue"))
+        else:
+            # we could provide a error message for the model to continue similar to
+            # https://github.com/xingyaoww/mint-bench/blob/main/mint/envs/general_env.py#L18-L23
+            # observation = INVALID_INPUT_MESSAGE
+            # self._history.append(Message(Role.ASSISTANT, observation))
+            # print(colored("===ENV OBSERVATION:===\n" + observation, "blue"))
+            return AgentEchoAction(content=INVALID_INPUT_MESSAGE)  # warning message to itself
 
-        Parameters:
-        - message (str): The chat message or command.
-        """
-        raise NotImplementedError
-
-    # TODO: implement these abstract methods
-    def add_event(self, event: Event) -> None:
-        raise NotImplementedError("Implement this abstract method")
-
-    def step(self, cmd_mgr: CommandManager) -> Event:
-        raise NotImplementedError("Implement this abstract method")
 
     def search_memory(self, query: str) -> List[str]:
         raise NotImplementedError("Implement this abstract method")

+ 123 - 27
agenthub/langchains_agent/__init__.py

@@ -1,8 +1,28 @@
-from typing import List, Any
+from typing import List, Dict, Type
 
+import agenthub.langchains_agent.utils.llm as llm
 from opendevin.agent import Agent
-from agenthub.langchains_agent.utils.agent import Agent as LangchainsAgentImpl
-from opendevin.lib.event import Event
+from opendevin.action import (
+    Action,
+    CmdRunAction,
+    CmdKillAction,
+    BrowseURLAction,
+    FileReadAction,
+    FileWriteAction,
+    AgentRecallAction,
+    AgentThinkAction,
+    AgentFinishAction,
+)
+from opendevin.observation import (
+    Observation,
+    CmdOutputObservation,
+    BrowserOutputObservation,
+)
+from opendevin.state import State
+
+from agenthub.langchains_agent.utils.monologue import Monologue
+from agenthub.langchains_agent.utils.memory import LongTermMemory
+
 
 INITIAL_THOUGHTS = [
     "I exist!",
@@ -43,59 +63,135 @@ INITIAL_THOUGHTS = [
 ]
 
 
+MAX_OUTPUT_LENGTH = 5000
+MAX_MONOLOGUE_LENGTH = 20000
+
+
+ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
+    "run": CmdRunAction,
+    "kill": CmdKillAction,
+    "browse": BrowseURLAction,
+    "read": FileReadAction,
+    "write": FileWriteAction,
+    "recall": AgentRecallAction,
+    "think": AgentThinkAction,
+    "finish": AgentFinishAction,
+}
+
+CLASS_TO_ACTION_TYPE: Dict[Type[Action], str] = {v: k for k, v in ACTION_TYPE_TO_CLASS.items()}
+
 class LangchainsAgent(Agent):
     _initialized = False
-    agent: Any = None
+
+    def __init__(self, model_name: str):
+        super().__init__(model_name)
+        self.monologue = Monologue(self.model_name)
+        self.memory = LongTermMemory()
+
+    def _add_event(self, event: dict):
+        if 'output' in event['args'] and len(event['args']['output']) > MAX_OUTPUT_LENGTH:
+            event['args']['output'] = event['args']['output'][:MAX_OUTPUT_LENGTH] + "..."
+
+        self.monologue.add_event(event)
+        self.memory.add_event(event)
+        if self.monologue.get_total_length() > MAX_MONOLOGUE_LENGTH:
+            self.monologue.condense()
 
     def _initialize(self):
         if self._initialized:
             return
+
         if self.instruction is None or self.instruction == "":
             raise ValueError("Instruction must be provided")
-        self.agent = LangchainsAgentImpl(self.instruction, self.model_name)
+
         next_is_output = False
         for thought in INITIAL_THOUGHTS:
             thought = thought.replace("$TASK", self.instruction)
             if next_is_output:
-                event = Event("output", {"output": thought})
+                d = {"action": "output", "args": {"output": thought}}
                 next_is_output = False
             else:
                 if thought.startswith("RUN"):
                     command = thought.split("RUN ")[1]
-                    event = Event("run", {"command": command})
+                    d = {"action": "run", "args": {"command": command}}
                     next_is_output = True
+
                 elif thought.startswith("RECALL"):
                     query = thought.split("RECALL ")[1]
-                    event = Event("recall", {"query": query})
+                    d = {"action": "recall", "args": {"query": query}}
                     next_is_output = True
+
                 elif thought.startswith("BROWSE"):
                     url = thought.split("BROWSE ")[1]
-                    event = Event("browse", {"url": url})
+                    d = {"action": "browse", "args": {"url": url}}
                     next_is_output = True
                 else:
-                    event = Event("think", {"thought": thought})
-            self.agent.add_event(event)
-        self._initialized = True
+                    d = {"action": "think", "args": {"thought": thought}}
 
-    def add_event(self, event: Event) -> None:
-        if self.agent:
-            self.agent.add_event(event)
+        self._add_event(d)
+        self._initialized = True
 
-    def step(self, cmd_mgr) -> Event:
+    def step(self, state: State) -> Action:
         self._initialize()
-        return self.agent.get_next_action(cmd_mgr)
+        # TODO: make langchains agent use Action & Observation
+        # completly from ground up
 
-    def search_memory(self, query: str) -> List[str]:
-        return self.agent.memory.search(query)
+        # Translate state to action_dict
+        for prev_action, obs in state.updated_info:
+            if isinstance(obs, CmdOutputObservation):
+                if obs.error:
+                    d = {"action": "error", "args": {"output": obs.content}}
+                else:
+                    d = {"action": "output", "args": {"output": obs.content}}
+            # elif isinstance(obs, UserMessageObservation):
+            #     d = {"action": "output", "args": {"output": obs.message}}
+            # elif isinstance(obs, AgentMessageObservation):
+            #     d = {"action": "output", "args": {"output": obs.message}}
+            elif isinstance(obs, (BrowserOutputObservation, Observation)):
+                d = {"action": "output", "args": {"output": obs.content}}
+            else:
+                raise NotImplementedError(f"Unknown observation type: {obs}")
+            self._add_event(d)
 
-    def chat(self, message: str) -> None:
-        """
-        Optional method for interactive communication with the agent during its execution. Implementations
-        can use this method to modify the agent's behavior or state based on chat inputs.
 
-        Parameters:
-        - message (str): The chat message or command.
-        """
-        raise NotImplementedError
+            if isinstance(prev_action, CmdRunAction):
+                d = {"action": "run", "args": {"command": prev_action.command}}
+            elif isinstance(prev_action, CmdKillAction):
+                d = {"action": "kill", "args": {"id": prev_action.id}}
+            elif isinstance(prev_action, BrowseURLAction):
+                d = {"action": "browse", "args": {"url": prev_action.url}}
+            elif isinstance(prev_action, FileReadAction):
+                d = {"action": "read", "args": {"file": prev_action.path}}
+            elif isinstance(prev_action, FileWriteAction):
+                d = {"action": "write", "args": {"file": prev_action.path, "content": prev_action.contents}}
+            elif isinstance(prev_action, AgentRecallAction):
+                d = {"action": "recall", "args": {"query": prev_action.query}}
+            elif isinstance(prev_action, AgentThinkAction):
+                d = {"action": "think", "args": {"thought": prev_action.thought}}
+            elif isinstance(prev_action, AgentFinishAction):
+                d = {"action": "finish"}
+            else:
+                raise NotImplementedError(f"Unknown action type: {prev_action}")
+            self._add_event(d)
+
+        state.updated_info = []
+            
+        action_dict = llm.request_action(
+            self.instruction,
+            self.monologue.get_thoughts(),
+            self.model_name,
+            state.background_commands_obs,
+        )
+        if action_dict is None:
+            action_dict = {"action": "think", "args": {"thought": "..."}}
+
+        # Translate action_dict to Action
+        action = ACTION_TYPE_TO_CLASS[action_dict["action"]](**action_dict["args"])
+        self.latest_action = action
+        return action
+
+    def search_memory(self, query: str) -> List[str]:
+        return self.memory.search(query)
+
 
 Agent.register("LangchainsAgent", LangchainsAgent)

+ 0 - 37
agenthub/langchains_agent/utils/agent.py

@@ -1,37 +0,0 @@
-from agenthub.langchains_agent.utils.monologue import Monologue
-from agenthub.langchains_agent.utils.memory import LongTermMemory
-from opendevin.lib.event import Event
-import agenthub.langchains_agent.utils.llm as llm
-
-MAX_OUTPUT_LENGTH = 5000
-MAX_MONOLOGUE_LENGTH = 20000
-
-class Agent:
-    def __init__(self, task, model_name):
-        self.task = task
-        self.model_name = model_name
-        self.monologue = Monologue(model_name)
-        self.memory = LongTermMemory()
-
-    def add_event(self, event):
-        if 'output' in event.args and len(event.args['output']) > MAX_OUTPUT_LENGTH:
-            event.args['output'] = event.args['output'][:MAX_OUTPUT_LENGTH] + "..."
-        self.monologue.add_event(event)
-        self.memory.add_event(event)
-        if self.monologue.get_total_length() > MAX_MONOLOGUE_LENGTH:
-            self.monologue.condense()
-
-    def get_next_action(self, cmd_mgr):
-        action_dict = llm.request_action(
-            self.task,
-            self.monologue.get_thoughts(),
-            self.model_name,
-            cmd_mgr.background_commands
-        )
-        if action_dict is None:
-            # TODO: this seems to happen if the LLM response isn't valid JSON. Maybe it should be an `error` instead? How should we handle this case?
-            return Event('think', {'thought': '...'})
-        event = Event(action_dict['action'], action_dict['args'])
-        self.latest_action = event
-        return event
-

+ 44 - 26
agenthub/langchains_agent/utils/llm.py

@@ -4,11 +4,16 @@ from . import json
 
 if os.getenv("DEBUG"):
     from langchain.globals import set_debug
+
     set_debug(True)
 
 from typing import List
 from langchain_core.pydantic_v1 import BaseModel
 
+from opendevin.observation import (
+    CmdOutputObservation,
+)
+
 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -88,60 +93,73 @@ The action key may be `summarize`, and `args.summary` should contain the summary
 You can also use the same action and args from the source monologue.
 """
 
-class Action(BaseModel):
+
+class _ActionDict(BaseModel):
     action: str
     args: dict
 
+
 class NewMonologue(BaseModel):
-    new_monologue: List[Action]
+    new_monologue: List[_ActionDict]
+
 
 def get_chain(template, model_name):
-    assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable to use langchains_agent."
-    llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"), model_name=model_name)
+    assert (
+        "OPENAI_API_KEY" in os.environ
+    ), "Please set the OPENAI_API_KEY environment variable to use langchains_agent."
+    llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"), model_name=model_name)  # type: ignore
     prompt = PromptTemplate.from_template(template)
     llm_chain = LLMChain(prompt=prompt, llm=llm)
     return llm_chain
 
-def summarize_monologue(thoughts, model_name):
+
+def summarize_monologue(thoughts: List[dict], model_name):
     llm_chain = get_chain(MONOLOGUE_SUMMARY_PROMPT, model_name)
     parser = JsonOutputParser(pydantic_object=NewMonologue)
-    resp = llm_chain.invoke({'monologue': json.dumps({'old_monologue': thoughts})})
+    resp = llm_chain.invoke({"monologue": json.dumps({"old_monologue": thoughts})})
+
     if os.getenv("DEBUG"):
         print("resp", resp)
-    parsed = parser.parse(resp['text'])
-    return parsed['new_monologue']
+    parsed = parser.parse(resp["text"])
+    return parsed["new_monologue"]
+
 
-def request_action(task, thoughts, model_name, background_commands=[]):
+def request_action(
+    task,
+    thoughts: List[dict],
+    model_name: str,
+    background_commands_obs: List[CmdOutputObservation] = [],
+):
     llm_chain = get_chain(ACTION_PROMPT, model_name)
-    parser = JsonOutputParser(pydantic_object=Action)
-    hint = ''
+    parser = JsonOutputParser(pydantic_object=_ActionDict)
+    hint = ""
     if len(thoughts) > 0:
         latest_thought = thoughts[-1]
-        if latest_thought.action == 'think':
-            if latest_thought.args['thought'].startswith("OK so my task is"):
+        if latest_thought["action"] == 'think':
+            if latest_thought["args"]['thought'].startswith("OK so my task is"):
                 hint = "You're just getting started! What should you do first?"
             else:
                 hint = "You've been thinking a lot lately. Maybe it's time to take action?"
-        elif latest_thought.action == 'error':
+        elif latest_thought["action"] == 'error':
             hint = "Looks like that last command failed. Maybe you need to fix it, or try something else."
 
     bg_commands_message = ""
-    if len(background_commands) > 0:
+    if len(background_commands_obs) > 0:
         bg_commands_message = "The following commands are running in the background:"
-        for id, command in background_commands.items():
-            bg_commands_message += f"\n`{id}`: {command.command}"
+        for command_obs in background_commands_obs:
+            bg_commands_message += f"\n`{command_obs.command_id}`: {command_obs.command}"
         bg_commands_message += "\nYou can end any process by sending a `kill` action with the numerical `id` above."
 
     latest_thought = thoughts[-1]
-    resp = llm_chain.invoke({
-        "monologue": json.dumps(thoughts),
-        "hint": hint,
-        "task": task,
-        "background_commands": bg_commands_message,
-    })
+    resp = llm_chain.invoke(
+        {
+            "monologue": json.dumps(thoughts),
+            "hint": hint,
+            "task": task,
+            "background_commands": bg_commands_message,
+        }
+    )
     if os.getenv("DEBUG"):
         print("resp", resp)
-    parsed = parser.parse(resp['text'])
+    parsed = parser.parse(resp["text"])
     return parsed
-
-

+ 2 - 2
agenthub/langchains_agent/utils/memory.py

@@ -18,9 +18,9 @@ class LongTermMemory:
     def add_event(self, event):
         doc = Document(
             text=json.dumps(event),
-            doc_id=self.thought_idx,
+            doc_id=str(self.thought_idx),
             extra_info={
-                "type": event.action,
+                "type": event["action"],
                 "idx": self.thought_idx,
             },
         )

+ 3 - 4
agenthub/langchains_agent/utils/monologue.py

@@ -1,6 +1,4 @@
 import agenthub.langchains_agent.utils.json as json
-from opendevin.lib.event import Event
-
 import agenthub.langchains_agent.utils.llm as llm
 
 class Monologue:
@@ -8,7 +6,7 @@ class Monologue:
         self.thoughts = []
         self.model_name = model_name
 
-    def add_event(self, t):
+    def add_event(self, t: dict):
         self.thoughts.append(t)
 
     def get_thoughts(self):
@@ -19,6 +17,7 @@ class Monologue:
 
     def condense(self):
         new_thoughts = llm.summarize_monologue(self.thoughts, self.model_name)
-        self.thoughts = [Event(t['action'], t['args']) for t in new_thoughts]
+        # self.thoughts = [Event(t['action'], t['args']) for t in new_thoughts]
+        self.thoughts = new_thoughts
 
 

+ 19 - 0
opendevin/action/__init__.py

@@ -0,0 +1,19 @@
+from .base import Action, NullAction
+from .bash import CmdRunAction, CmdKillAction
+from .browse import BrowseURLAction
+from .fileop import FileReadAction, FileWriteAction
+from .agent import AgentRecallAction, AgentThinkAction, AgentFinishAction, AgentEchoAction
+
+__all__ = [
+    "Action",
+    "NullAction",
+    "CmdRunAction",
+    "CmdKillAction",
+    "BrowseURLAction",
+    "FileReadAction",
+    "FileWriteAction",
+    "AgentRecallAction",
+    "AgentThinkAction",
+    "AgentFinishAction",
+    "AgentEchoAction",
+]

+ 58 - 0
opendevin/action/agent.py

@@ -0,0 +1,58 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from opendevin.observation import AgentRecallObservation, AgentMessageObservation, Observation
+from .base import ExecutableAction, NotExecutableAction
+if TYPE_CHECKING:
+    from opendevin.controller import AgentController
+
+
+@dataclass
+class AgentRecallAction(ExecutableAction):
+    query: str
+
+    def run(self, controller: "AgentController") -> AgentRecallObservation:
+        return AgentRecallObservation(
+            content="Recalling memories...",
+            memories=controller.agent.search_memory(self.query)
+        )
+
+    @property
+    def message(self) -> str:
+        return f"Recalling memories with query: {self.query}"
+
+
+@dataclass
+class AgentThinkAction(NotExecutableAction):
+    thought: str
+    runnable: bool = False
+
+    def run(self, controller: "AgentController") -> "Observation":
+        raise NotImplementedError
+
+    @property
+    def message(self) -> str:
+        return f"Thinking: {self.thought}"
+
+@dataclass
+class AgentEchoAction(ExecutableAction):
+    content: str
+    runnable: bool = True
+
+    def run(self, controller: "AgentController") -> "Observation":
+        return AgentMessageObservation(self.content)
+
+    @property
+    def message(self) -> str:
+        return f"Echoing: {self.content}"
+
+@dataclass
+class AgentFinishAction(NotExecutableAction):
+    runnable: bool = False
+
+    def run(self, controller: "AgentController") -> "Observation":
+        raise NotImplementedError
+
+    @property
+    def message(self) -> str:
+        return "Finished!"

+ 45 - 0
opendevin/action/base.py

@@ -0,0 +1,45 @@
+from dataclasses import dataclass
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from opendevin.controller import AgentController
+    from opendevin.observation import Observation
+
+@dataclass
+class Action:
+    def run(self, controller: "AgentController") -> "Observation":
+        raise NotImplementedError
+
+    def to_dict(self):
+        return {"action": self.__class__.__name__, "args": self.__dict__, "message": self.message}
+
+    @property
+    def executable(self) -> bool:
+        raise NotImplementedError
+
+    @property
+    def message(self) -> str:
+        raise NotImplementedError
+
+
+
+class ExecutableAction(Action):
+    @property
+    def executable(self) -> bool:
+        return True
+
+
+class NotExecutableAction(Action):
+    @property
+    def executable(self) -> bool:
+        return False
+
+class NullAction(NotExecutableAction):
+    """An action that does nothing.
+    This is used when the agent need to receive user follow-up messages from the frontend.
+    """
+
+    @property
+    def message(self) -> str:
+        return "No action"

+ 32 - 0
opendevin/action/bash.py

@@ -0,0 +1,32 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from .base import ExecutableAction
+
+if TYPE_CHECKING:
+    from opendevin.controller import AgentController
+    from opendevin.observation import CmdOutputObservation
+
+
+@dataclass
+class CmdRunAction(ExecutableAction):
+    command: str
+    background: bool = False
+
+    def run(self, controller: "AgentController") -> "CmdOutputObservation":
+        return controller.command_manager.run_command(self.command, self.background)
+
+    @property
+    def message(self) -> str:
+        return f"Running command: {self.command}"
+
+@dataclass
+class CmdKillAction(ExecutableAction):
+    id: int
+
+    def run(self, controller: "AgentController") -> "CmdOutputObservation":
+        return controller.command_manager.kill_command(self.id)
+
+    @property
+    def message(self) -> str:
+        return f"Killing command: {self.id}"

+ 21 - 0
opendevin/action/browse.py

@@ -0,0 +1,21 @@
+import requests
+
+from dataclasses import dataclass
+from opendevin.observation import BrowserOutputObservation
+
+from .base import ExecutableAction
+
+@dataclass
+class BrowseURLAction(ExecutableAction):
+    url: str
+
+    def run(self, *args, **kwargs) -> BrowserOutputObservation:
+        response = requests.get(self.url)
+        return BrowserOutputObservation(
+            content=response.text,
+            url=self.url
+        )
+
+    @property
+    def message(self) -> str:
+        return f"Browsing URL: {self.url}"

+ 47 - 0
opendevin/action/fileop.py

@@ -0,0 +1,47 @@
+import os
+from dataclasses import dataclass
+
+from opendevin.observation import Observation
+from .base import ExecutableAction
+
+# This is the path where the workspace is mounted in the container
+# The LLM sometimes returns paths with this prefix, so we need to remove it
+PATH_PREFIX = "/workspace/"
+
+def resolve_path(base_path, file_path):
+    if file_path.startswith(PATH_PREFIX):
+        file_path = file_path[len(PATH_PREFIX):]
+    return os.path.join(base_path, file_path)
+
+
+@dataclass
+class FileReadAction(ExecutableAction):
+    path: str
+    base_path: str = ""
+
+    def run(self, *args, **kwargs) -> Observation:
+        path = resolve_path(self.base_path, self.path)
+        with open(path, 'r') as file:
+            return Observation(file.read())
+
+    @property
+    def message(self) -> str:
+        return f"Reading file: {self.path}"
+
+
+@dataclass
+class FileWriteAction(ExecutableAction):
+    path: str
+    contents: str
+    base_path: str = ""
+
+    def run(self, *args, **kwargs) -> Observation:
+        path = resolve_path(self.base_path, self.path)
+        with open(path, 'w') as file:
+            file.write(self.contents)
+        return Observation(f"File written to {path}")
+
+    @property
+    def message(self) -> str:
+        return f"Writing file: {self.path}"
+

+ 11 - 74
opendevin/agent.py

@@ -1,40 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import List, Dict, Type
-from dataclasses import dataclass
-from enum import Enum
+from typing import List, Dict, Type, TYPE_CHECKING
 
-from .lib.event import Event
-from .lib.command_manager import CommandManager
-
-class Role(Enum):
-    SYSTEM = "system"  # system message for LLM
-    USER = "user"  # the user
-    ASSISTANT = "assistant"  # the agent
-    ENVIRONMENT = "environment"  # the environment (e.g., bash shell, web browser, etc.)
-
-@dataclass
-class Message:
-    """
-    This data class represents a message sent by an agent to another agent or user.
-    """
-
-    role: Role
-    content: str
-    # TODO: add more fields as needed
-
-    def to_dict(self) -> Dict:
-        """
-        Converts the message to a dictionary (OpenAI chat-completion format).
-
-        Returns:
-        - message (Dict): A dictionary representation of the message.
-        """
-        role = self.role.value
-        content = self.content
-        if self.role == Role.ENVIRONMENT:
-            content = f"Environment Observation:\n{content}"
-            role = "user"  # treat environment messages as user messages
-        return {"role": role, "content": content}
+if TYPE_CHECKING:
+    from opendevin.action import Action
+    from opendevin.state import State
 
 
 class Agent(ABC):
@@ -45,26 +14,15 @@ class Agent(ABC):
     It tracks the execution status and maintains a history of interactions.
 
     :param instruction: The instruction for the agent to execute.
-    :param workspace_dir: The working directory for the agent.
     :param model_name: The litellm name of the model to use for the agent.
-    :param max_steps: The maximum number of steps to run the agent.
     """
 
-    _registry: Dict[str, Type['Agent']] = {}
+    _registry: Dict[str, Type["Agent"]] = {}
 
-    def __init__(
-        self,
-        workspace_dir: str,
-        model_name: str,
-        max_steps: int = 100
-    ):
-        self.instruction = ""
-        self.workspace_dir = workspace_dir
+    def __init__(self, model_name: str):
         self.model_name = model_name
-        self.max_steps = max_steps
-
+        self.instruction: str = ""  # need to be set before step
         self._complete = False
-        self._history: List[Message] = []
 
     @property
     def complete(self) -> bool:
@@ -76,28 +34,8 @@ class Agent(ABC):
         """
         return self._complete
 
-    @property
-    def history(self) -> List[Message]:
-        """
-        Provides the history of interactions or state changes since the instruction was initiated.
-
-        Returns:
-        - history (List[str]): A list of strings representing the history.
-        """
-        return self._history
-
     @abstractmethod
-    def add_event(self, event: Event) -> None:
-        """
-        Adds an event to the agent's history.
-
-        Parameters:
-        - event (Event): The event to add to the history.
-        """
-        pass
-
-    @abstractmethod
-    def step(self, cmd_mgr: CommandManager) -> Event:
+    def step(self, state: "State") -> "Action":
         """
         Starts the execution of the assigned instruction. This method should
         be implemented by subclasses to define the specific execution logic.
@@ -123,12 +61,11 @@ class Agent(ABC):
         to prepare the agent for restarting the instruction or cleaning up before destruction.
 
         """
-        self.instruction = ''
+        self.instruction = ""
         self._complete = False
-        self._history = []
 
     @classmethod
-    def register(cls, name: str, agent_cls: Type['Agent']):
+    def register(cls, name: str, agent_cls: Type["Agent"]):
         """
         Registers an agent class in the registry.
 
@@ -141,7 +78,7 @@ class Agent(ABC):
         cls._registry[name] = agent_cls
 
     @classmethod
-    def get_cls(cls, name: str) -> Type['Agent']:
+    def get_cls(cls, name: str) -> Type["Agent"]:
         """
         Retrieves an agent class from the registry.
 

+ 0 - 68
opendevin/controller.py

@@ -1,68 +0,0 @@
-import asyncio
-
-from opendevin.lib.command_manager import CommandManager
-from opendevin.lib.event import Event
-
-def print_callback(event):
-    print(event.str_truncated(), flush=True)
-
-class AgentController:
-    def __init__(self, agent, workdir, max_iterations=100, callbacks=[]):
-        self.agent = agent
-        self.max_iterations = max_iterations
-        self.background_commands = []
-        self.command_manager = CommandManager(workdir)
-        self.callbacks = callbacks
-        self.callbacks.append(self.agent.add_event)
-        self.callbacks.append(print_callback)
-
-    async def add_user_event(self, event: Event):
-        await self.handle_action(event)
-
-    async def start_loop(self, task):
-        try:
-            self.agent.instruction = task
-            for i in range(self.max_iterations):
-                print("STEP", i, flush=True)
-                done = await self.step()
-                if done:
-                    print("FINISHED", flush=True)
-                    break
-        except Exception as e:
-            print("Error in loop", e, flush=True)
-            pass
-
-
-    async def step(self) -> bool:
-        log_events = self.command_manager.get_background_events()
-        for event in log_events:
-            await self.run_callbacks(event)
-
-        try:
-            action_event = self.agent.step(self.command_manager)
-        except Exception as e:
-            action_event = Event('error', {'error': str(e)})
-        if action_event is None:
-            action_event = Event('error', {'error': "Agent did not return an event"})
-
-        await self.handle_action(action_event)
-        return action_event.action == 'finish'
-
-    async def handle_action(self, event: Event):
-        print("=== HANDLING EVENT ===", flush=True)
-        await self.run_callbacks(event)
-        print("---  EVENT OUTPUT  ---", flush=True)
-        output_event = event.run(self)
-        await self.run_callbacks(output_event)
-
-    async def run_callbacks(self, event):
-        if event is None:
-            return
-        for callback in self.callbacks:
-            idx = self.callbacks.index(callback)
-            try:
-                callback(event)
-            except Exception as e:
-                print("Callback error:" + str(idx), e, flush=True)
-                pass
-        await asyncio.sleep(0.001) # Give back control for a tick, so we can await in callbacks

+ 91 - 0
opendevin/controller/__init__.py

@@ -0,0 +1,91 @@
+import asyncio
+from typing import List, Callable, Tuple
+
+from opendevin.state import State
+from opendevin.agent import Agent
+from opendevin.action import (
+    Action,
+    NullAction,
+    FileReadAction,
+    FileWriteAction,
+    AgentFinishAction,
+)
+from opendevin.observation import (
+    Observation,
+    NullObservation
+)
+
+
+from .command_manager import CommandManager
+
+
+class AgentController:
+    def __init__(
+        self,
+        agent: Agent,
+        workdir: str,
+        max_iterations: int = 100,
+        callbacks: List[Callable] = [],
+    ):
+        self.agent = agent
+        self.max_iterations = max_iterations
+        self.workdir = workdir
+        self.command_manager = CommandManager(workdir)
+        self.callbacks = callbacks
+        self.state_updated_info: List[Tuple[Action, Observation]] = []
+
+    def get_current_state(self) -> State:
+        # update observations & actions
+        state = State(
+            background_commands_obs=self.command_manager.get_background_obs(),
+            updated_info=self.state_updated_info,
+        )
+        self.state_updated_info = []
+        return state
+
+    def add_observation(self, observation: Observation):
+        self.state_updated_info.append((NullAction(), observation))
+
+    async def start_loop(self, task_instruction: str):
+        try:
+            self.agent.instruction = task_instruction
+            for i in range(self.max_iterations):
+                print("STEP", i, flush=True)
+
+                state: State = self.get_current_state()
+                action: Action = self.agent.step(state)
+                
+                print("ACTION", action, flush=True)
+                for _callback_fn in self.callbacks:
+                    _callback_fn(action)
+                
+                if isinstance(action, AgentFinishAction):
+                    print("FINISHED", flush=True)
+                    break
+                if isinstance(action, (FileReadAction, FileWriteAction)):
+                    action_cls = action.__class__
+                    _kwargs = action.__dict__
+                    _kwargs["base_path"] = self.workdir
+                    action = action_cls(**_kwargs)
+                    print(action, flush=True)
+                print("---", flush=True)
+
+
+                if action.executable:
+                    observation: Observation = action.run(self)
+                else:
+                    print("ACTION NOT EXECUTABLE", flush=True)
+                    observation = NullObservation("")
+                print("OBSERVATION", observation, flush=True)
+                self.state_updated_info.append((action, observation))
+                
+                print(observation, flush=True)
+                for _callback_fn in self.callbacks:
+                    _callback_fn(observation)
+
+                print("==============", flush=True)
+
+                await asyncio.sleep(0.001)
+        except Exception as e:
+            print("Error in loop", e, flush=True)
+            pass

+ 66 - 0
opendevin/controller/command_manager.py

@@ -0,0 +1,66 @@
+from typing import List
+
+from opendevin.observation import CmdOutputObservation
+from opendevin.sandbox.sandbox import DockerInteractive
+
+
+class BackgroundCommand:
+    def __init__(self, id: int, command: str, dir: str):
+        self.command = command
+        self.id = id
+        self.shell = DockerInteractive(id=str(id), workspace_dir=dir)
+        self.shell.execute_in_background(command)
+
+    def get_logs(self) -> str:
+        # TODO: get an exit code if process is exited
+        return self.shell.read_logs()
+
+
+class CommandManager:
+    def __init__(self, dir):
+        self.cur_id = 0
+        self.directory = dir
+        self.background_commands = {}
+        self.shell = DockerInteractive(id="default", workspace_dir=dir)
+
+    def run_command(self, command: str, background=False) -> CmdOutputObservation:
+        if background:
+            return self._run_background(command)
+        else:
+            return self._run_immediately(command)
+
+    def _run_immediately(self, command: str) -> CmdOutputObservation:
+        exit_code, output = self.shell.execute(command)
+        return CmdOutputObservation(
+            content=output,
+            command_id=self.cur_id,
+            command=command,
+            exit_code=exit_code
+        )
+
+    def _run_background(self, command: str) -> CmdOutputObservation:
+        bg_cmd = BackgroundCommand(self.cur_id, command, self.directory)
+        self.cur_id += 1
+        self.background_commands[bg_cmd.id] = bg_cmd
+        return CmdOutputObservation(
+            content=f"Background command started.  To stop it, send a `kill` action with id {bg_cmd.id}",
+            command_id=bg_cmd.id,
+            command=command,
+            exit_code=0
+        )
+
+    def kill_command(self, id: int):
+        # TODO: get log events before killing
+        self.background_commands[id].shell.close()
+        del self.background_commands[id]
+
+    def get_background_obs(self) -> List[CmdOutputObservation]:
+        obs = []
+        for _id, cmd in self.background_commands.items():
+            output = cmd.get_logs()
+            obs.append(
+                CmdOutputObservation(
+                    content=output, command_id=_id, command=cmd.command
+                )
+            )
+        return obs

+ 0 - 5
opendevin/lib/actions/__init__.py

@@ -1,5 +0,0 @@
-from .browse import browse
-from .write import write
-from .read import read
-
-__all__ = ['run', 'kill', 'browse', 'write', 'read']

+ 0 - 6
opendevin/lib/actions/browse.py

@@ -1,6 +0,0 @@
-import requests
-
-def browse(url):
-    response = requests.get(url)
-    return response.text
-

+ 0 - 7
opendevin/lib/actions/read.py

@@ -1,7 +0,0 @@
-from .util import resolve_path
-
-def read(base_path, file_path):
-    file_path = resolve_path(base_path, file_path)
-    with open(file_path, 'r') as file:
-        return file.read()
-

+ 0 - 10
opendevin/lib/actions/util.py

@@ -1,10 +0,0 @@
-import os
-
-# This is the path where the workspace is mounted in the container
-# The LLM sometimes returns paths with this prefix, so we need to remove it
-PATH_PREFIX = "/workspace/"
-
-def resolve_path(base_path, file_path):
-    if file_path.startswith(PATH_PREFIX):
-        file_path = file_path[len(PATH_PREFIX):]
-    return os.path.join(base_path, file_path)

+ 0 - 8
opendevin/lib/actions/write.py

@@ -1,8 +0,0 @@
-from .util import resolve_path
-
-def write(base_path, file_path, contents):
-    file_path = resolve_path(base_path, file_path)
-    with open(file_path, 'w') as file:
-        file.write(contents)
-    return ""
-

+ 0 - 56
opendevin/lib/command_manager.py

@@ -1,56 +0,0 @@
-from typing import List
-
-from opendevin.lib.event import Event
-from opendevin.sandbox.sandbox import DockerInteractive
-
-class BackgroundCommand:
-    def __init__(self, id: int, command: str, dir: str):
-        self.command = command
-        self.id = id
-        self.shell = DockerInteractive(id=str(id), workspace_dir=dir)
-        self.shell.execute_in_background(command)
-
-    def get_logs(self):
-        # TODO: get an exit code if process is exited
-        return self.shell.read_logs()
-
-class CommandManager:
-    def __init__(self, dir):
-        self.cur_id = 0
-        self.directory = dir
-        self.background_commands = {}
-        self.shell = DockerInteractive(id="default", workspace_dir=dir)
-
-    def run_command(self, command: str, background=False) -> str:
-        if background:
-            return self.run_background(command)
-        else:
-            return self.run_immediately(command)
-
-    def run_immediately(self, command: str) -> str:
-        exit_code, output = self.shell.execute(command)
-        if exit_code != 0:
-            raise ValueError('Command failed with exit code ' + str(exit_code) + ': ' + output)
-        return output
-
-    def run_background(self, command: str) -> str:
-        bg_cmd = BackgroundCommand(self.cur_id, command, self.directory)
-        self.cur_id += 1
-        self.background_commands[bg_cmd.id] = bg_cmd
-        return "Background command started. To stop it, send a `kill` action with id " + str(bg_cmd.id)
-
-    def kill_command(self, id: int):
-        # TODO: get log events before killing
-        self.background_commands[id].shell.close()
-        del self.background_commands[id]
-
-    def get_background_events(self) -> List[Event]:
-        events = []
-        for id, cmd in self.background_commands.items():
-            output = cmd.get_logs()
-            events.append(Event('output', {
-                'output': output,
-                'id': id,
-                'command': cmd.command,
-            }))
-        return events

+ 0 - 93
opendevin/lib/event.py

@@ -1,93 +0,0 @@
-import opendevin.lib.actions as actions
-
-ACTION_TYPES = ['initialize', 'start', 'summarize', 'run', 'kill', 'browse', 'read', 'write', 'recall', 'think', 'output', 'error', 'finish']
-RUNNABLE_ACTIONS = ['run', 'kill', 'browse', 'read', 'write', 'recall']
-
-class Event:
-    def __init__(self, action, args, message=None):
-        if action not in ACTION_TYPES:
-            raise ValueError('Invalid action type: ' + action)
-        self.action = action
-        self.args = args
-        self.message = message
-
-    def __str__(self):
-        return self.action + " " + str(self.args)
-
-    def str_truncated(self, max_len=1000):
-        s = str(self)
-        if len(s) > max_len:
-            s = s[:max_len] + '...'
-        return s
-
-    def to_dict(self):
-        return {
-            'action': self.action,
-            'args': self.args
-        }
-
-    def get_message(self) -> str:
-        if self.message is not None:
-            return self.message
-        if self.action == 'run':
-            return 'Running command: ' + self.args['command']
-        elif self.action == 'kill':
-            return 'Killing command: ' + self.args['id']
-        elif self.action == 'browse':
-            return 'Browsing: ' + self.args['url']
-        elif self.action == 'read':
-            return 'Reading file: ' + self.args['path']
-        elif self.action == 'write':
-            return 'Writing to file: ' + self.args['path']
-        elif self.action == 'recall':
-            return 'Recalling memory: ' + self.args['query']
-        elif self.action == 'think':
-            return self.args['thought']
-        elif self.action == 'output':
-            return "Got output."
-        elif self.action == 'error':
-            return "Got an error: " + self.args['output']
-        elif self.action == 'finish':
-            return "Finished!"
-        else:
-            return ""
-
-    def is_runnable(self):
-        return self.action in RUNNABLE_ACTIONS
-
-    def run(self, agent_controller):
-        if not self.is_runnable():
-            return None
-        action = 'output'
-        try:
-            output = self._run_and_get_output(agent_controller)
-        except Exception as e:
-            output = 'Error: ' + str(e)
-            action = 'error'
-        out_event = Event(action, {'output': output})
-        return out_event
-
-    def _run_and_get_output(self, agent_controller) -> str:
-        if self.action == 'run':
-            cmd = self.args['command']
-            background = False
-            if 'background' in self.args and self.args['background']:
-                background = True
-            return agent_controller.command_manager.run_command(cmd, background)
-        if self.action == 'kill':
-            id = self.args['id']
-            return agent_controller.command_manager.kill_command(id)
-        elif self.action == 'browse':
-            url = self.args['url']
-            return actions.browse(url)
-        elif self.action == 'read':
-            path = self.args['path']
-            return actions.read(agent_controller.command_manager.directory, path)
-        elif self.action == 'write':
-            path = self.args['path']
-            contents = self.args['contents']
-            return actions.write(agent_controller.command_manager.directory, path, contents)
-        elif self.action == 'recall':
-            return agent_controller.agent.search_memory(self.args['query'])
-        else:
-            raise ValueError('Invalid action type')

+ 39 - 10
opendevin/main.py

@@ -1,25 +1,54 @@
-from typing import Type
 import asyncio
 import argparse
 
+from typing import Type
+
 import agenthub # noqa F401 (we import this to get the agents registered)
 from opendevin.agent import Agent
 from opendevin.controller import AgentController
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run an agent with a specific task")
-    parser.add_argument("-d", "--directory", required=True, type=str, help="The working directory for the agent")
-    parser.add_argument("-t", "--task", required=True, type=str, help="The task for the agent to perform")
-    parser.add_argument("-c", "--agent-cls", default="LangchainsAgent", type=str, help="The agent class to use")
-    parser.add_argument("-m", "--model-name", default="gpt-4-0125-preview", type=str, help="The (litellm) model name to use")
+    parser.add_argument(
+        "-d",
+        "--directory",
+        required=True,
+        type=str,
+        help="The working directory for the agent",
+    )
+    parser.add_argument(
+        "-t",
+        "--task",
+        required=True,
+        type=str,
+        help="The task for the agent to perform",
+    )
+    parser.add_argument(
+        "-c",
+        "--agent-cls",
+        default="LangchainsAgent",
+        type=str,
+        help="The agent class to use",
+    )
+    parser.add_argument(
+        "-m",
+        "--model-name",
+        default="gpt-4-0125-preview",
+        type=str,
+        help="The (litellm) model name to use",
+    )
+    parser.add_argument(
+        "-i",
+        "--max-iterations",
+        default=10,
+        type=int,
+        help="The maximum number of iterations to run the agent",
+    )
     args = parser.parse_args()
 
     print(f"Running agent {args.agent_cls} (model: {args.model_name}, directory: {args.directory}) with task: \"{args.task}\"")
 
     AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
-    agent = AgentCls(
-        workspace_dir=args.directory,
-        model_name=args.model_name
-    )
-    controller = AgentController(agent, args.directory)
+    agent = AgentCls(model_name=args.model_name)
+    controller = AgentController(agent, workdir=args.directory, max_iterations=args.max_iterations)
     asyncio.run(controller.start_loop(args.task))

+ 115 - 0
opendevin/observation.py

@@ -0,0 +1,115 @@
+import copy
+from typing import List
+from dataclasses import dataclass
+
+
+@dataclass
+class Observation:
+    """
+    This data class represents an observation of the environment.
+    """
+
+    content: str
+
+    def __str__(self) -> str:
+        return self.content
+
+    def to_dict(self) -> dict:
+        """Converts the observation to a dictionary."""
+        extras = copy.deepcopy(self.__dict__)
+        extras.pop("content", None)
+        return {
+            "observation": self.__class__.__name__,
+            "content": self.content,
+            "extras": extras,
+            "message": self.message,
+        }
+
+    @property
+    def message(self) -> str:
+        """Returns a message describing the observation."""
+        return "The agent made an observation."
+
+
+@dataclass
+class CmdOutputObservation(Observation):
+    """
+    This data class represents the output of a command.
+    """
+
+    command_id: int
+    command: str
+    exit_code: int = 0
+
+    @property
+    def error(self) -> bool:
+        return self.exit_code != 0
+
+    @property
+    def message(self) -> str:
+        return f'The agent observed command "{self.command}" executed with exit code {self.exit_code}.'
+
+
+@dataclass
+class BrowserOutputObservation(Observation):
+    """
+    This data class represents the output of a browser.
+    """
+
+    url: str
+
+    @property
+    def message(self) -> str:
+        return "The agent observed the browser output at URL."
+
+
+@dataclass
+class UserMessageObservation(Observation):
+    """
+    This data class represents a message sent by the user.
+    """
+
+    role: str = "user"
+
+    @property
+    def message(self) -> str:
+        return "The agent received a message from the user."
+
+
+@dataclass
+class AgentMessageObservation(Observation):
+    """
+    This data class represents a message sent by the agent.
+    """
+
+    role: str = "assistant"
+
+    @property
+    def message(self) -> str:
+        return "The agent received a message from itself."
+
+
+@dataclass
+class AgentRecallObservation(Observation):
+    """
+    This data class represents a list of memories recalled by the agent.
+    """
+
+    memories: List[str]
+    role: str = "assistant"
+
+    @property
+    def message(self) -> str:
+        return "The agent recalled memories."
+
+
+@dataclass
+class NullObservation(Observation):
+    """
+    This data class represents a null observation.
+    This is used when the produced action is NOT executable.
+    """
+
+    @property
+    def message(self) -> str:
+        return ""

+ 10 - 3
opendevin/sandbox/sandbox.py

@@ -48,7 +48,13 @@ class DockerInteractive:
         self.container_name = f"sandbox-{self.instance_id}"
 
         self.restart_docker_container()
-        self.execute('useradd --shell /bin/bash -u {uid} -o -c \"\" -m devin && su devin')
+        uid = os.getuid()
+        exit_code, logs = self.container.exec_run([
+            '/bin/bash', '-c',
+            f'useradd --shell /bin/bash -u {uid} -o -c \"\" -m devin'
+            ],
+            workdir="/workspace"
+        )
         # regester container cleanup function
         atexit.register(self.cleanup)
 
@@ -70,12 +76,13 @@ class DockerInteractive:
         return logs
 
     def execute(self, cmd: str) -> Tuple[int, str]:
-        exit_code, logs = self.container.exec_run(['/bin/bash', '-c', cmd], workdir="/workspace")
+        # TODO: each execute is not stateful! We need to keep track of the current working directory
+        exit_code, logs = self.container.exec_run(['su', 'devin', '-c', cmd], workdir="/workspace")
         return exit_code, logs.decode('utf-8')
 
     def execute_in_background(self, cmd: str) -> None:
         self.log_time = time.time()
-        result = self.container.exec_run(['/bin/bash', '-c', cmd], socket=True, workdir="/workspace")
+        result = self.container.exec_run(['su', 'devin', '-c', cmd], socket=True, workdir="/workspace")
         self.log_generator = result.output # socket.SocketIO
         self.log_generator._sock.setblocking(0)
 

+ 57 - 21
opendevin/server/session.py

@@ -1,14 +1,42 @@
 import os
 import asyncio
-from typing import Optional
+from typing import Optional, Dict, Type
 
 from fastapi import WebSocketDisconnect
 
 from opendevin.agent import Agent
 from opendevin.controller import AgentController
-from opendevin.lib.event import Event
 
-DEFAULT_WORKSPACE_DIR = os.getenv("WORKSPACE_DIR", os.getcwd())
+from opendevin.action import (
+    Action,
+    CmdRunAction,
+    CmdKillAction,
+    BrowseURLAction,
+    FileReadAction,
+    FileWriteAction,
+    AgentRecallAction,
+    AgentThinkAction,
+    AgentFinishAction,
+)
+from opendevin.observation import (
+    Observation,
+    UserMessageObservation
+)
+
+# NOTE: this is a temporary solution - but hopefully we can use Action/Observation throughout the codebase
+ACTION_TYPE_TO_CLASS: Dict[str, Type[Action]] = {
+    "run": CmdRunAction,
+    "kill": CmdKillAction,
+    "browse": BrowseURLAction,
+    "read": FileReadAction,
+    "write": FileWriteAction,
+    "recall": AgentRecallAction,
+    "think": AgentThinkAction,
+    "finish": AgentFinishAction,
+}
+
+
+DEFAULT_WORKSPACE_DIR = os.getenv("WORKSPACE_DIR", os.path.join(os.getcwd(), "workspace"))
 
 def parse_event(data):
     if "action" not in data:
@@ -20,7 +48,11 @@ def parse_event(data):
     message = None
     if "message" in data:
         message = data["message"]
-    return Event(action, args, message)
+    return {
+        "action": action,
+        "args": args,
+        "message": message,
+    }
 
 class Session:
     def __init__(self, websocket):
@@ -57,15 +89,21 @@ class Session:
                 if event is None:
                     await self.send_error("Invalid event")
                     continue
-                if event.action == "initialize":
+                if event["action"] == "initialize":
                     await self.create_controller(event)
-                elif event.action == "start":
+                elif event["action"] == "start":
                     await self.start_task(event)
                 else:
                     if self.controller is None:
                         await self.send_error("No agent started. Please wait a second...")
+
+                    elif event["action"] == "chat":
+                        self.controller.add_observation(UserMessageObservation(event["message"]))
                     else:
-                        await self.controller.add_user_event(event)
+                        # TODO: we only need to implement user message for now
+                        # since even Devin does not support having the user taking other
+                        # actions (e.g., edit files) while the agent is running
+                        raise NotImplementedError
 
         except WebSocketDisconnect as e:
             self.websocket = None
@@ -83,30 +121,28 @@ class Session:
         model = "gpt-4-0125-preview"
         if start_event and "model" in start_event.args:
             model = start_event.args["model"]
-
+        
+        if not os.path.exists(directory):
+            print(f"Workspace directory {directory} does not exist. Creating it...")
+            os.makedirs(directory)
+        directory = os.path.relpath(directory, os.getcwd())
+        
         AgentCls = Agent.get_cls(agent_cls)
-        self.agent = AgentCls(
-            workspace_dir=directory,
-            model_name=model,
-        )
+        self.agent = AgentCls(model_name=model)
         self.controller = AgentController(self.agent, directory, callbacks=[self.on_agent_event])
         await self.send({"action": "initialize", "message": "Control loop started."})
 
     async def start_task(self, start_event):
-        if "task" not in start_event.args:
+        if "task" not in start_event["args"]:
             await self.send_error("No task specified")
             return
         await self.send_message("Starting new task...")
-        task = start_event.args["task"]
+        task = start_event["args"]["task"]
         if self.controller is None:
             await self.send_error("No agent started. Please wait a second...")
             return
         self.agent_task = asyncio.create_task(self.controller.start_loop(task), name="agent loop")
 
-    def on_agent_event(self, event):
-        evt = {
-            "action": event.action,
-            "message": event.get_message(),
-            "args": event.args,
-        }
-        asyncio.create_task(self.send(evt), name="send event in callback")
+    def on_agent_event(self, event: Observation | Action):
+        event_dict = event.to_dict()
+        asyncio.create_task(self.send(event_dict), name="send event in callback")

+ 16 - 0
opendevin/state.py

@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+
+from opendevin.action import (
+    Action,
+)
+from opendevin.observation import (
+    Observation,
+    CmdOutputObservation,
+)
+
+
+@dataclass
+class State:
+    background_commands_obs: List[CmdOutputObservation]
+    updated_info: List[Tuple[Action, Observation]]