Browse Source

refactor state management (#258)

* refactor state management

* rm import

* move task into state

* revert change

* revert a few files
Robert Brennan 2 years ago
parent
commit
94120f2b5d

+ 3 - 4
agenthub/codeact_agent/codeact_agent.py

@@ -67,16 +67,15 @@ class CodeActAgent(Agent):
         """
         super().__init__(llm)
         self.messages: List[Mapping[str, str]] = []
-        self.instruction: str = ""
 
     def step(self, state: State) -> Action:
         if len(self.messages) == 0:
-            assert self.instruction, "Expecting instruction to be set"
+            assert state.task, "Expecting instruction to be set"
             self.messages = [
                 {"role": "system", "content": SYSTEM_MESSAGE},
-                {"role": "user", "content": self.instruction},
+                {"role": "user", "content": state.task},
             ]
-            print(colored("===USER:===\n" + self.instruction, "green"))
+            print(colored("===USER:===\n" + state.task, "green"))
         updated_info = state.updated_info
         if updated_info:
             for prev_action, obs in updated_info:

+ 6 - 7
agenthub/langchains_agent/langchains_agent.py

@@ -1,9 +1,8 @@
 from typing import List
-
-from opendevin.llm.llm import LLM
 from opendevin.agent import Agent
 from opendevin.state import State
 from opendevin.action import Action
+from opendevin.llm.llm import LLM
 import agenthub.langchains_agent.utils.prompts as prompts
 from agenthub.langchains_agent.utils.monologue import Monologue
 from agenthub.langchains_agent.utils.memory import LongTermMemory
@@ -83,18 +82,18 @@ class LangchainsAgent(Agent):
         if self.monologue.get_total_length() > MAX_MONOLOGUE_LENGTH:
             self.monologue.condense(self.llm)
 
-    def _initialize(self):
+    def _initialize(self, task):
         if self._initialized:
             return
 
-        if self.instruction is None or self.instruction == "":
+        if task is None or task == "":
             raise ValueError("Instruction must be provided")
         self.monologue = Monologue()
         self.memory = LongTermMemory()
 
         next_is_output = False
         for thought in INITIAL_THOUGHTS:
-            thought = thought.replace("$TASK", self.instruction)
+            thought = thought.replace("$TASK", task)
             if next_is_output:
                 d = {"action": "output", "args": {"output": thought}}
                 next_is_output = False
@@ -120,7 +119,7 @@ class LangchainsAgent(Agent):
         self._initialized = True
 
     def step(self, state: State) -> Action:
-        self._initialize()
+        self._initialize(state.task)
         # TODO: make langchains agent use Action & Observation
         # completely from ground up
 
@@ -164,7 +163,7 @@ class LangchainsAgent(Agent):
         state.updated_info = []
 
         prompt = prompts.get_request_action_prompt(
-            self.instruction,
+            state.task,
             self.monologue.get_thoughts(),
             state.background_commands_obs,
         )

+ 0 - 5
opendevin/agent.py

@@ -12,9 +12,6 @@ class Agent(ABC):
     executing a specific instruction and allowing human interaction with the
     agent during execution.
     It tracks the execution status and maintains a history of interactions.
-
-    :param instruction: The instruction for the agent to execute.
-    :param model_name: The litellm name of the model to use for the agent.
     """
 
     _registry: Dict[str, Type["Agent"]] = {}
@@ -23,7 +20,6 @@ class Agent(ABC):
         self,
         llm: LLM,
     ):
-        self.instruction = ""
         self.llm = llm
         self._complete = False
 
@@ -64,7 +60,6 @@ class Agent(ABC):
         to prepare the agent for restarting the instruction or cleaning up before destruction.
 
         """
-        self.instruction = ""
         self._complete = False
 
     @classmethod

+ 17 - 15
opendevin/controller/__init__.py

@@ -1,5 +1,5 @@
 import asyncio
-from typing import List, Callable, Tuple
+from typing import List, Callable
 import traceback
 
 from opendevin.state import State
@@ -36,27 +36,26 @@ class AgentController:
         self.workdir = workdir
         self.command_manager = CommandManager(workdir)
         self.callbacks = callbacks
-        self.state_updated_info: List[Tuple[Action, Observation]] = []
 
-    def get_current_state(self) -> State:
-        # update observations & actions
-        state = State(
-            background_commands_obs=self.command_manager.get_background_obs(),
-            updated_info=self.state_updated_info,
-        )
-        self.state_updated_info = []
-        return state
+    def update_state_for_step(self, i):
+        self.state.iteration = i
+        self.state.background_commands_obs = self.command_manager.get_background_obs()
+
+    def update_state_after_step(self):
+        self.state.updated_info = []
 
     def add_history(self, action: Action, observation: Observation):
         if not isinstance(action, Action):
             raise ValueError("action must be an instance of Action")
         if not isinstance(observation, Observation):
             raise ValueError("observation must be an instance of Observation")
-        self.state_updated_info.append((action, observation))
+        self.state.history.append((action, observation))
+        self.state.updated_info.append((action, observation))
+
 
-    async def start_loop(self, task_instruction: str):
+    async def start_loop(self, task: str):
         finished = False
-        self.agent.instruction = task_instruction
+        self.state = State(task)
         for i in range(self.max_iterations):
             try:
                 finished = await self.step(i)
@@ -78,16 +77,19 @@ class AgentController:
             await self._run_callbacks(obs)
             print_with_indent("\nBACKGROUND LOG:\n%s" % obs)
 
-        state: State = self.get_current_state()
+        self.update_state_for_step(i)
         action: Action = NullAction()
         observation: Observation = NullObservation("")
         try:
-            action = self.agent.step(state)
+            action = self.agent.step(self.state)
+            if action is None:
+                raise ValueError("Agent must return an action")
             print_with_indent("\nACTION:\n%s" % action)
         except Exception as e:
             observation = AgentErrorObservation(str(e))
             print_with_indent("\nAGENT ERROR:\n%s" % observation)
             traceback.print_exc()
+        self.update_state_after_step()
 
         await self._run_callbacks(action)
 

+ 6 - 4
opendevin/state.py

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import List, Tuple
 
 from opendevin.action import (
@@ -9,8 +9,10 @@ from opendevin.observation import (
     CmdOutputObservation,
 )
 
-
 @dataclass
 class State:
-    background_commands_obs: List[CmdOutputObservation]
-    updated_info: List[Tuple[Action, Observation]]
+    task: str
+    iteration: int = 0
+    background_commands_obs: List[CmdOutputObservation] = field(default_factory=list)
+    history: List[Tuple[Action, Observation]] = field(default_factory=list)
+    updated_info: List[Tuple[Action, Observation]] = field(default_factory=list)