Просмотр исходного кода

Add ability to restore the cli session (optional) (#2699)

* add ability to restore the main session

* add quick log

* rename to cli session
Engel Nyst 1 год назад
Родитель
Сommit
2d9bb56763

+ 1 - 1
evaluation/EDA/run_infer.py

@@ -156,7 +156,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': {
             'success': test_result,
             'final_message': final_message,

+ 1 - 1
evaluation/agent_bench/run_infer.py

@@ -236,7 +236,7 @@ def process_instance(
         'metadata': metadata,
         'history': histories,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': {
             'agent_answer': agent_answer,
             'final_answer': final_ans,

+ 1 - 1
evaluation/biocoder/run_infer.py

@@ -243,7 +243,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': test_result,
     }
 

+ 1 - 1
evaluation/bird/run_infer.py

@@ -245,7 +245,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': test_result,
     }
     return output

+ 1 - 1
evaluation/gaia/run_infer.py

@@ -197,7 +197,7 @@ def process_instance(instance, agent_class, metadata, reset_logger: bool = True)
                 for action, obs in state.history
             ],
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
         }
     except Exception:

+ 1 - 1
evaluation/gorilla/run_infer.py

@@ -153,7 +153,7 @@ def process_instance(
                 for action, obs in state.history
             ],
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
         }
     except Exception:
         logger.error('Process instance failed')

+ 1 - 1
evaluation/gpqa/run_infer.py

@@ -288,7 +288,7 @@ def process_instance(
                 for action, obs in state.history
             ],
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
         }
 

+ 1 - 1
evaluation/humanevalfix/run_infer.py

@@ -241,7 +241,7 @@ def process_instance(
                 for action, obs in state.history
             ],
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
         }
     except Exception:

+ 1 - 1
evaluation/logic_reasoning/run_infer.py

@@ -263,7 +263,7 @@ def process_instance(
             'metrics': metrics,
             'final_message': final_message,
             'messages': messages,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
         }
     except Exception:

+ 1 - 1
evaluation/miniwob/run_infer.py

@@ -100,7 +100,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': reward,
     }
 

+ 1 - 1
evaluation/mint/run_infer.py

@@ -185,7 +185,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': task_state.success if task_state else False,
     }
 

+ 1 - 1
evaluation/swe_bench/run_infer.py

@@ -343,7 +343,7 @@ IMPORTANT TIPS:
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': test_result,
     }
 

+ 1 - 1
evaluation/toolqa/run_infer.py

@@ -139,7 +139,7 @@ def process_instance(task, agent_class, metadata, reset_logger: bool = True):
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
     }
     return output
 

+ 1 - 1
evaluation/webarena/run_infer.py

@@ -100,7 +100,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': reward,
     }
 

+ 22 - 10
opendevin/controller/agent_controller.py

@@ -74,14 +74,19 @@ class AgentController:
         self._step_lock = asyncio.Lock()
         self.id = sid
         self.agent = agent
-        if initial_state is None:
-            self.state = State(inputs={}, max_iterations=max_iterations)
-        else:
-            self.state = initial_state
+
+        # subscribe to the event stream
         self.event_stream = event_stream
         self.event_stream.subscribe(
             EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
         )
+
+        # state from the previous session, state from a parent agent, or a fresh state
+        self._set_initial_state(
+            state=initial_state,
+            max_iterations=max_iterations,
+        )
+
         self.max_budget_per_task = max_budget_per_task
         if not is_delegate:
             self.agent_task = asyncio.create_task(self._start_step_loop())
@@ -108,9 +113,9 @@ class AgentController:
         - the string message should be user-friendly, it will be shown in the UI
         - an ErrorObservation can be sent to the LLM by the agent, with the exception message, so it can self-correct next time
         """
-        self.state.error = message
+        self.state.last_error = message
         if exception:
-            self.state.error += f': {exception}'
+            self.state.last_error += f': {exception}'
         await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
 
     async def add_history(self, action: Action, observation: Observation):
@@ -182,7 +187,7 @@ class AgentController:
         self.agent.reset()
 
     async def set_agent_state_to(self, new_state: AgentState):
-        logger.info(
+        logger.debug(
             f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
         )
 
@@ -235,7 +240,7 @@ class AgentController:
             return
 
         if self._pending_action:
-            logger.info(
+            logger.debug(
                 f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
             )
             await asyncio.sleep(1)
@@ -336,8 +341,15 @@ class AgentController:
     def get_state(self):
         return self.state
 
-    def set_state(self, state: State):
-        self.state = state
+    def _set_initial_state(
+        self, state: State | None, max_iterations: int = MAX_ITERATIONS
+    ):
+        # state from the previous session, state from a parent agent, or a new state
+        # note that this is called twice when restoring a previous session, first with state=None
+        if state is None:
+            self.state = State(inputs={}, max_iterations=max_iterations)
+        else:
+            self.state = state
 
     def _is_stuck(self):
         # check if delegate stuck

+ 1 - 1
opendevin/controller/state/state.py

@@ -34,7 +34,7 @@ class State:
     updated_info: list[tuple[Action, Observation]] = field(default_factory=list)
     inputs: dict = field(default_factory=dict)
     outputs: dict = field(default_factory=dict)
-    error: str | None = None
+    last_error: str | None = None
     agent_state: AgentState = AgentState.LOADING
     resume_state: AgentState | None = None
     metrics: Metrics = Metrics()

+ 2 - 0
opendevin/core/config.py

@@ -154,6 +154,7 @@ class AppConfig(metaclass=Singleton):
         sandbox_timeout: The timeout for the sandbox.
         debug: Whether to enable debugging.
         enable_auto_lint: Whether to enable auto linting. This is False by default, for regular runs of the app. For evaluation, please set this to True.
+        enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
         file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
         file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
         file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
@@ -195,6 +196,7 @@ class AppConfig(metaclass=Singleton):
     enable_auto_lint: bool = (
         False  # once enabled, OpenDevin would lint files after editing
     )
+    enable_cli_session: bool = False
     file_uploads_max_file_size_mb: int = 0
     file_uploads_restrict_file_types: bool = False
     file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])

+ 42 - 8
opendevin/core/main.py

@@ -1,13 +1,13 @@
 import asyncio
 import os
 import sys
-from typing import Callable, Optional, Type
+from typing import Callable, Type
 
 import agenthub  # noqa F401 (we import this to get the agents registered)
 from opendevin.controller import AgentController
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import args, get_llm_config_arg
+from opendevin.core.config import args, config, get_llm_config_arg
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.schema import AgentState
 from opendevin.events import EventSource, EventStream, EventStreamSubscriber
@@ -33,11 +33,11 @@ def read_task_from_stdin() -> str:
 async def main(
     task_str: str = '',
     exit_on_message: bool = False,
-    fake_user_response_fn: Optional[Callable[[Optional[State]], str]] = None,
-    sandbox: Optional[Sandbox] = None,
-    runtime_tools_config: Optional[dict] = None,
+    fake_user_response_fn: Callable[[State | None], str] | None = None,
+    sandbox: Sandbox | None = None,
+    runtime_tools_config: dict | None = None,
     sid: str | None = None,
-) -> Optional[State]:
+) -> State | None:
     """Main coroutine to run the agent controller with task input flexibility.
     It's only used when you launch opendevin backend directly via cmdline.
 
@@ -82,16 +82,33 @@ async def main(
         )
         llm = LLM(args.model_name)
 
+    # set up the agent
     AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
     agent = AgentCls(llm=llm)
 
-    event_stream = EventStream('main' + ('_' + sid if sid else ''))
+    # set up the event stream
+    cli_session = 'main' + ('_' + sid if sid else '')
+    event_stream = EventStream(cli_session)
+
+    # restore cli session if enabled
+    initial_state = None
+    if config.enable_cli_session:
+        try:
+            logger.info('Restoring agent state from cli session')
+            initial_state = State.restore_from_session(cli_session)
+        except Exception as e:
+            print('Error restoring state', e)
+
+    # init controller with this initial state
     controller = AgentController(
         agent=agent,
         max_iterations=args.max_iterations,
         max_budget_per_task=args.max_budget_per_task,
         event_stream=event_stream,
+        initial_state=initial_state,
     )
+
+    # runtime and tools
     runtime = ServerRuntime(event_stream=event_stream, sandbox=sandbox)
     runtime.init_sandbox_plugins(controller.agent.sandbox_plugins)
     runtime.init_runtime_tools(
@@ -110,7 +127,18 @@ async def main(
             task = f.read()
             logger.info(f'Dynamic Eval task: {task}')
 
-    await event_stream.add_event(MessageAction(content=task), EventSource.USER)
+    # start event is a MessageAction with the task, either resumed or new
+    if config.enable_cli_session and initial_state is not None:
+        # we're resuming the previous session
+        await event_stream.add_event(
+            MessageAction(
+                content="Let's get back on track. If you experienced errors before, do NOT resume your task. Ask me about it."
+            ),
+            EventSource.USER,
+        )
+    elif initial_state is None:
+        # init with the provided task
+        await event_stream.add_event(MessageAction(content=task), EventSource.USER)
 
     async def on_event(event: Event):
         if isinstance(event, AgentStateChangedObservation):
@@ -134,6 +162,12 @@ async def main(
     ]:
         await asyncio.sleep(1)  # Give back control for a tick, so the agent can run
 
+    # save session when we're about to close
+    if config.enable_cli_session:
+        end_state = controller.get_state()
+        end_state.save_to_session(cli_session)
+
+    # close when done
     await controller.close()
     runtime.close()
     return controller.get_state()

+ 1 - 1
opendevin/server/session/agent.py

@@ -111,7 +111,7 @@ class AgentSession:
         )
         try:
             agent_state = State.restore_from_session(self.sid)
-            self.controller.set_state(agent_state)
+            self.controller._set_initial_state(agent_state)
             logger.info(f'Restored agent state from session, sid: {self.sid}')
         except Exception as e:
             print('Error restoring state', e)

+ 6 - 6
tests/integration/test_agent.py

@@ -40,7 +40,7 @@ def test_write_simple_script():
     task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
     # Verify the script file exists
     script_path = os.path.join(workspace_base, 'hello.sh')
@@ -86,7 +86,7 @@ def test_edits():
     task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
     # Verify bad.txt has been fixed
     text = """This is a stupid typo.
@@ -112,7 +112,7 @@ def test_ipython():
     task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
     # Verify the file exists
     file_path = os.path.join(workspace_base, 'test.txt')
@@ -140,7 +140,7 @@ def test_simple_task_rejection():
     task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
     final_state: State = asyncio.run(main(task))
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
     assert isinstance(final_state.history[-1][0], AgentRejectAction)
 
 
@@ -157,7 +157,7 @@ def test_ipython_module():
     task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
     # Verify the file exists
     file_path = os.path.join(workspace_base, 'test.txt')
@@ -186,6 +186,6 @@ def test_browse_internet(http_server):
     task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
     assert isinstance(final_state.history[-1][0], AgentFinishAction)
     assert 'OpenDevin is all you need!' in str(final_state.history)