Просмотр исходного кода

Add ability to restore the cli session (optional) (#2699)

* add ability to restore the main session

* add quick log

* rename to cli session
Engel Nyst 1 год назад
Родитель
Сommit
2d9bb56763

+ 1 - 1
evaluation/EDA/run_infer.py

@@ -156,7 +156,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': {
         'test_result': {
             'success': test_result,
             'success': test_result,
             'final_message': final_message,
             'final_message': final_message,

+ 1 - 1
evaluation/agent_bench/run_infer.py

@@ -236,7 +236,7 @@ def process_instance(
         'metadata': metadata,
         'metadata': metadata,
         'history': histories,
         'history': histories,
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': {
         'test_result': {
             'agent_answer': agent_answer,
             'agent_answer': agent_answer,
             'final_answer': final_ans,
             'final_answer': final_ans,

+ 1 - 1
evaluation/biocoder/run_infer.py

@@ -243,7 +243,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': test_result,
         'test_result': test_result,
     }
     }
 
 

+ 1 - 1
evaluation/bird/run_infer.py

@@ -245,7 +245,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': test_result,
         'test_result': test_result,
     }
     }
     return output
     return output

+ 1 - 1
evaluation/gaia/run_infer.py

@@ -197,7 +197,7 @@ def process_instance(instance, agent_class, metadata, reset_logger: bool = True)
                 for action, obs in state.history
                 for action, obs in state.history
             ],
             ],
             'metrics': metrics,
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
             'test_result': test_result,
         }
         }
     except Exception:
     except Exception:

+ 1 - 1
evaluation/gorilla/run_infer.py

@@ -153,7 +153,7 @@ def process_instance(
                 for action, obs in state.history
                 for action, obs in state.history
             ],
             ],
             'metrics': metrics,
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
         }
         }
     except Exception:
     except Exception:
         logger.error('Process instance failed')
         logger.error('Process instance failed')

+ 1 - 1
evaluation/gpqa/run_infer.py

@@ -288,7 +288,7 @@ def process_instance(
                 for action, obs in state.history
                 for action, obs in state.history
             ],
             ],
             'metrics': metrics,
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
             'test_result': test_result,
         }
         }
 
 

+ 1 - 1
evaluation/humanevalfix/run_infer.py

@@ -241,7 +241,7 @@ def process_instance(
                 for action, obs in state.history
                 for action, obs in state.history
             ],
             ],
             'metrics': metrics,
             'metrics': metrics,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
             'test_result': test_result,
         }
         }
     except Exception:
     except Exception:

+ 1 - 1
evaluation/logic_reasoning/run_infer.py

@@ -263,7 +263,7 @@ def process_instance(
             'metrics': metrics,
             'metrics': metrics,
             'final_message': final_message,
             'final_message': final_message,
             'messages': messages,
             'messages': messages,
-            'error': state.error if state and state.error else None,
+            'error': state.last_error if state and state.last_error else None,
             'test_result': test_result,
             'test_result': test_result,
         }
         }
     except Exception:
     except Exception:

+ 1 - 1
evaluation/miniwob/run_infer.py

@@ -100,7 +100,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': reward,
         'test_result': reward,
     }
     }
 
 

+ 1 - 1
evaluation/mint/run_infer.py

@@ -185,7 +185,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': task_state.success if task_state else False,
         'test_result': task_state.success if task_state else False,
     }
     }
 
 

+ 1 - 1
evaluation/swe_bench/run_infer.py

@@ -343,7 +343,7 @@ IMPORTANT TIPS:
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': test_result,
         'test_result': test_result,
     }
     }
 
 

+ 1 - 1
evaluation/toolqa/run_infer.py

@@ -139,7 +139,7 @@ def process_instance(task, agent_class, metadata, reset_logger: bool = True):
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
     }
     }
     return output
     return output
 
 

+ 1 - 1
evaluation/webarena/run_infer.py

@@ -100,7 +100,7 @@ def process_instance(
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
             (event_to_dict(action), event_to_dict(obs)) for action, obs in state.history
         ],
         ],
         'metrics': metrics,
         'metrics': metrics,
-        'error': state.error if state and state.error else None,
+        'error': state.last_error if state and state.last_error else None,
         'test_result': reward,
         'test_result': reward,
     }
     }
 
 

+ 22 - 10
opendevin/controller/agent_controller.py

@@ -74,14 +74,19 @@ class AgentController:
         self._step_lock = asyncio.Lock()
         self._step_lock = asyncio.Lock()
         self.id = sid
         self.id = sid
         self.agent = agent
         self.agent = agent
-        if initial_state is None:
-            self.state = State(inputs={}, max_iterations=max_iterations)
-        else:
-            self.state = initial_state
+
+        # subscribe to the event stream
         self.event_stream = event_stream
         self.event_stream = event_stream
         self.event_stream.subscribe(
         self.event_stream.subscribe(
             EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
             EventStreamSubscriber.AGENT_CONTROLLER, self.on_event, append=is_delegate
         )
         )
+
+        # state from the previous session, state from a parent agent, or a fresh state
+        self._set_initial_state(
+            state=initial_state,
+            max_iterations=max_iterations,
+        )
+
         self.max_budget_per_task = max_budget_per_task
         self.max_budget_per_task = max_budget_per_task
         if not is_delegate:
         if not is_delegate:
             self.agent_task = asyncio.create_task(self._start_step_loop())
             self.agent_task = asyncio.create_task(self._start_step_loop())
@@ -108,9 +113,9 @@ class AgentController:
         - the string message should be user-friendly, it will be shown in the UI
         - the string message should be user-friendly, it will be shown in the UI
         - an ErrorObservation can be sent to the LLM by the agent, with the exception message, so it can self-correct next time
         - an ErrorObservation can be sent to the LLM by the agent, with the exception message, so it can self-correct next time
         """
         """
-        self.state.error = message
+        self.state.last_error = message
         if exception:
         if exception:
-            self.state.error += f': {exception}'
+            self.state.last_error += f': {exception}'
         await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
         await self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)
 
 
     async def add_history(self, action: Action, observation: Observation):
     async def add_history(self, action: Action, observation: Observation):
@@ -182,7 +187,7 @@ class AgentController:
         self.agent.reset()
         self.agent.reset()
 
 
     async def set_agent_state_to(self, new_state: AgentState):
     async def set_agent_state_to(self, new_state: AgentState):
-        logger.info(
+        logger.debug(
             f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
             f'[Agent Controller {self.id}] Setting agent({self.agent.name}) state from {self.state.agent_state} to {new_state}'
         )
         )
 
 
@@ -235,7 +240,7 @@ class AgentController:
             return
             return
 
 
         if self._pending_action:
         if self._pending_action:
-            logger.info(
+            logger.debug(
                 f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
                 f'[Agent Controller {self.id}] waiting for pending action: {self._pending_action}'
             )
             )
             await asyncio.sleep(1)
             await asyncio.sleep(1)
@@ -336,8 +341,15 @@ class AgentController:
     def get_state(self):
     def get_state(self):
         return self.state
         return self.state
 
 
-    def set_state(self, state: State):
-        self.state = state
+    def _set_initial_state(
+        self, state: State | None, max_iterations: int = MAX_ITERATIONS
+    ):
+        # state from the previous session, state from a parent agent, or a new state
+        # note that this is called twice when restoring a previous session, first with state=None
+        if state is None:
+            self.state = State(inputs={}, max_iterations=max_iterations)
+        else:
+            self.state = state
 
 
     def _is_stuck(self):
     def _is_stuck(self):
         # check if delegate stuck
         # check if delegate stuck

+ 1 - 1
opendevin/controller/state/state.py

@@ -34,7 +34,7 @@ class State:
     updated_info: list[tuple[Action, Observation]] = field(default_factory=list)
     updated_info: list[tuple[Action, Observation]] = field(default_factory=list)
     inputs: dict = field(default_factory=dict)
     inputs: dict = field(default_factory=dict)
     outputs: dict = field(default_factory=dict)
     outputs: dict = field(default_factory=dict)
-    error: str | None = None
+    last_error: str | None = None
     agent_state: AgentState = AgentState.LOADING
     agent_state: AgentState = AgentState.LOADING
     resume_state: AgentState | None = None
     resume_state: AgentState | None = None
     metrics: Metrics = Metrics()
     metrics: Metrics = Metrics()

+ 2 - 0
opendevin/core/config.py

@@ -154,6 +154,7 @@ class AppConfig(metaclass=Singleton):
         sandbox_timeout: The timeout for the sandbox.
         sandbox_timeout: The timeout for the sandbox.
         debug: Whether to enable debugging.
         debug: Whether to enable debugging.
         enable_auto_lint: Whether to enable auto linting. This is False by default, for regular runs of the app. For evaluation, please set this to True.
         enable_auto_lint: Whether to enable auto linting. This is False by default, for regular runs of the app. For evaluation, please set this to True.
+        enable_cli_session: Whether to enable saving and restoring the session when run from CLI.
         file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
         file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit.
         file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
         file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False.
         file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
         file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
@@ -195,6 +196,7 @@ class AppConfig(metaclass=Singleton):
     enable_auto_lint: bool = (
     enable_auto_lint: bool = (
         False  # once enabled, OpenDevin would lint files after editing
         False  # once enabled, OpenDevin would lint files after editing
     )
     )
+    enable_cli_session: bool = False
     file_uploads_max_file_size_mb: int = 0
     file_uploads_max_file_size_mb: int = 0
     file_uploads_restrict_file_types: bool = False
     file_uploads_restrict_file_types: bool = False
     file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])
     file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*'])

+ 42 - 8
opendevin/core/main.py

@@ -1,13 +1,13 @@
 import asyncio
 import asyncio
 import os
 import os
 import sys
 import sys
-from typing import Callable, Optional, Type
+from typing import Callable, Type
 
 
 import agenthub  # noqa F401 (we import this to get the agents registered)
 import agenthub  # noqa F401 (we import this to get the agents registered)
 from opendevin.controller import AgentController
 from opendevin.controller import AgentController
 from opendevin.controller.agent import Agent
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
 from opendevin.controller.state.state import State
-from opendevin.core.config import args, get_llm_config_arg
+from opendevin.core.config import args, config, get_llm_config_arg
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.schema import AgentState
 from opendevin.core.schema import AgentState
 from opendevin.events import EventSource, EventStream, EventStreamSubscriber
 from opendevin.events import EventSource, EventStream, EventStreamSubscriber
@@ -33,11 +33,11 @@ def read_task_from_stdin() -> str:
 async def main(
 async def main(
     task_str: str = '',
     task_str: str = '',
     exit_on_message: bool = False,
     exit_on_message: bool = False,
-    fake_user_response_fn: Optional[Callable[[Optional[State]], str]] = None,
-    sandbox: Optional[Sandbox] = None,
-    runtime_tools_config: Optional[dict] = None,
+    fake_user_response_fn: Callable[[State | None], str] | None = None,
+    sandbox: Sandbox | None = None,
+    runtime_tools_config: dict | None = None,
     sid: str | None = None,
     sid: str | None = None,
-) -> Optional[State]:
+) -> State | None:
     """Main coroutine to run the agent controller with task input flexibility.
     """Main coroutine to run the agent controller with task input flexibility.
     It's only used when you launch opendevin backend directly via cmdline.
     It's only used when you launch opendevin backend directly via cmdline.
 
 
@@ -82,16 +82,33 @@ async def main(
         )
         )
         llm = LLM(args.model_name)
         llm = LLM(args.model_name)
 
 
+    # set up the agent
     AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
     AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)
     agent = AgentCls(llm=llm)
     agent = AgentCls(llm=llm)
 
 
-    event_stream = EventStream('main' + ('_' + sid if sid else ''))
+    # set up the event stream
+    cli_session = 'main' + ('_' + sid if sid else '')
+    event_stream = EventStream(cli_session)
+
+    # restore cli session if enabled
+    initial_state = None
+    if config.enable_cli_session:
+        try:
+            logger.info('Restoring agent state from cli session')
+            initial_state = State.restore_from_session(cli_session)
+        except Exception as e:
+            print('Error restoring state', e)
+
+    # init controller with this initial state
     controller = AgentController(
     controller = AgentController(
         agent=agent,
         agent=agent,
         max_iterations=args.max_iterations,
         max_iterations=args.max_iterations,
         max_budget_per_task=args.max_budget_per_task,
         max_budget_per_task=args.max_budget_per_task,
         event_stream=event_stream,
         event_stream=event_stream,
+        initial_state=initial_state,
     )
     )
+
+    # runtime and tools
     runtime = ServerRuntime(event_stream=event_stream, sandbox=sandbox)
     runtime = ServerRuntime(event_stream=event_stream, sandbox=sandbox)
     runtime.init_sandbox_plugins(controller.agent.sandbox_plugins)
     runtime.init_sandbox_plugins(controller.agent.sandbox_plugins)
     runtime.init_runtime_tools(
     runtime.init_runtime_tools(
@@ -110,7 +127,18 @@ async def main(
             task = f.read()
             task = f.read()
             logger.info(f'Dynamic Eval task: {task}')
             logger.info(f'Dynamic Eval task: {task}')
 
 
-    await event_stream.add_event(MessageAction(content=task), EventSource.USER)
+    # start event is a MessageAction with the task, either resumed or new
+    if config.enable_cli_session and initial_state is not None:
+        # we're resuming the previous session
+        await event_stream.add_event(
+            MessageAction(
+                content="Let's get back on track. If you experienced errors before, do NOT resume your task. Ask me about it."
+            ),
+            EventSource.USER,
+        )
+    elif initial_state is None:
+        # init with the provided task
+        await event_stream.add_event(MessageAction(content=task), EventSource.USER)
 
 
     async def on_event(event: Event):
     async def on_event(event: Event):
         if isinstance(event, AgentStateChangedObservation):
         if isinstance(event, AgentStateChangedObservation):
@@ -134,6 +162,12 @@ async def main(
     ]:
     ]:
         await asyncio.sleep(1)  # Give back control for a tick, so the agent can run
         await asyncio.sleep(1)  # Give back control for a tick, so the agent can run
 
 
+    # save session when we're about to close
+    if config.enable_cli_session:
+        end_state = controller.get_state()
+        end_state.save_to_session(cli_session)
+
+    # close when done
     await controller.close()
     await controller.close()
     runtime.close()
     runtime.close()
     return controller.get_state()
     return controller.get_state()

+ 1 - 1
opendevin/server/session/agent.py

@@ -111,7 +111,7 @@ class AgentSession:
         )
         )
         try:
         try:
             agent_state = State.restore_from_session(self.sid)
             agent_state = State.restore_from_session(self.sid)
-            self.controller.set_state(agent_state)
+            self.controller._set_initial_state(agent_state)
             logger.info(f'Restored agent state from session, sid: {self.sid}')
             logger.info(f'Restored agent state from session, sid: {self.sid}')
         except Exception as e:
         except Exception as e:
             print('Error restoring state', e)
             print('Error restoring state', e)

+ 6 - 6
tests/integration/test_agent.py

@@ -40,7 +40,7 @@ def test_write_simple_script():
     task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
     task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
 
     # Verify the script file exists
     # Verify the script file exists
     script_path = os.path.join(workspace_base, 'hello.sh')
     script_path = os.path.join(workspace_base, 'hello.sh')
@@ -86,7 +86,7 @@ def test_edits():
     task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
     task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
 
     # Verify bad.txt has been fixed
     # Verify bad.txt has been fixed
     text = """This is a stupid typo.
     text = """This is a stupid typo.
@@ -112,7 +112,7 @@ def test_ipython():
     task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
     task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
 
     # Verify the file exists
     # Verify the file exists
     file_path = os.path.join(workspace_base, 'test.txt')
     file_path = os.path.join(workspace_base, 'test.txt')
@@ -140,7 +140,7 @@ def test_simple_task_rejection():
     task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
     task = 'Write a git commit message for the current staging area. Do not ask me for confirmation at any point.'
     final_state: State = asyncio.run(main(task))
     final_state: State = asyncio.run(main(task))
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
     assert isinstance(final_state.history[-1][0], AgentRejectAction)
     assert isinstance(final_state.history[-1][0], AgentRejectAction)
 
 
 
 
@@ -157,7 +157,7 @@ def test_ipython_module():
     task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
     task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
 
 
     # Verify the file exists
     # Verify the file exists
     file_path = os.path.join(workspace_base, 'test.txt')
     file_path = os.path.join(workspace_base, 'test.txt')
@@ -186,6 +186,6 @@ def test_browse_internet(http_server):
     task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
     task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     final_state: State = asyncio.run(main(task, exit_on_message=True))
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.agent_state == AgentState.STOPPED
-    assert final_state.error is None
+    assert final_state.last_error is None
     assert isinstance(final_state.history[-1][0], AgentFinishAction)
     assert isinstance(final_state.history[-1][0], AgentFinishAction)
     assert 'OpenDevin is all you need!' in str(final_state.history)
     assert 'OpenDevin is all you need!' in str(final_state.history)