Bläddra i källkod

Fix restore cli sessions (#3409)

* fix restore cli sessions

* pytest

* fix log message

* make sure sid is set

---------

Co-authored-by: mamoodi <mamoodiha@gmail.com>
Engel Nyst 1 år sedan
förälder
incheckning
9cb0bf97c1

+ 12 - 1
opendevin/core/config.py

@@ -665,7 +665,11 @@ def get_parser() -> argparse.ArgumentParser:
         help='The working directory for the agent',
     )
     parser.add_argument(
-        '-t', '--task', type=str, default='', help='The task for the agent to perform'
+        '-t',
+        '--task',
+        type=str,
+        default='',
+        help='The task for the agent to perform',
     )
     parser.add_argument(
         '-f',
@@ -725,6 +729,13 @@ def get_parser() -> argparse.ArgumentParser:
         type=str,
         help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml',
     )
+    parser.add_argument(
+        '-n',
+        '--name',
+        default='default',
+        type=str,
+        help='Name for the session',
+    )
     return parser
 
 

+ 26 - 3
opendevin/core/main.py

@@ -1,4 +1,5 @@
 import asyncio
+import hashlib
 import sys
 import uuid
 from typing import Callable, Type
@@ -47,9 +48,13 @@ async def create_runtime(
     sid: The session id.
     runtime_tools_config: (will be deprecated) The runtime tools config.
     """
+    # if sid is provided on the command line, use it as the name of the event stream
+    # otherwise generate it on the basis of the configured jwt_secret
+    # we can do this better, this is just so that the sid is retrieved when we want to restore the session
+    session_id = sid or generate_sid(config)
+
     # set up the event stream
     file_store = get_file_store(config.file_store, config.file_store_path)
-    session_id = 'main' + ('_' + sid if sid else str(uuid.uuid4()))
     event_stream = EventStream(session_id, file_store)
 
     # agent class
@@ -72,6 +77,7 @@ async def create_runtime(
 async def run_controller(
     config: AppConfig,
     task_str: str,
+    sid: str | None = None,
     runtime: Runtime | None = None,
     agent: Agent | None = None,
     exit_on_message: bool = False,
@@ -100,15 +106,18 @@ async def run_controller(
             config=agent_config,
         )
 
+    # make sure the session id is set
+    sid = sid or generate_sid(config)
+
     if runtime is None:
-        runtime = await create_runtime(config)
+        runtime = await create_runtime(config, sid=sid)
 
     event_stream = runtime.event_stream
     # restore cli session if enabled
     initial_state = None
     if config.enable_cli_session:
         try:
-            logger.info('Restoring agent state from cli session')
+            logger.info(f'Restoring agent state from cli session {event_stream.sid}')
             initial_state = State.restore_from_session(
                 event_stream.sid, event_stream.file_store
             )
@@ -179,6 +188,15 @@ async def run_controller(
     return state
 
 
+def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
+    """Generate a session id based on the session name and the jwt secret."""
+    session_name = session_name or str(uuid.uuid4())
+    jwt_secret = config.jwt_secret
+
+    hash_str = hashlib.sha256(f'{session_name}{jwt_secret}'.encode('utf-8')).hexdigest()
+    return f'{session_name}_{hash_str[:16]}'
+
+
 if __name__ == '__main__':
     args = parse_arguments()
 
@@ -207,6 +225,10 @@ if __name__ == '__main__':
     # Set default agent
     config.default_agent = args.agent_cls
 
+    # Set session name
+    session_name = args.name
+    sid = generate_sid(config, session_name)
+
     # if max budget per task is not sent on the command line, use the config value
     if args.max_budget_per_task is not None:
         config.max_budget_per_task = args.max_budget_per_task
@@ -217,5 +239,6 @@ if __name__ == '__main__':
         run_controller(
             config=config,
             task_str=task_str,
+            sid=sid,
         )
     )

+ 2 - 2
opendevin/runtime/client/runtime.py

@@ -57,7 +57,7 @@ class EventStreamRuntime(Runtime):
         self.session: aiohttp.ClientSession | None = None
 
         self.instance_id = (
-            sid + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
+            sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
         )
         # TODO: We can switch to aiodocker when `get_od_sandbox_image` is updated to use aiodocker
         self.docker_client: docker.DockerClient = self._init_docker_client()
@@ -193,7 +193,7 @@ class EventStreamRuntime(Runtime):
         wait=tenacity.wait_exponential(multiplier=2, min=10, max=60),
     )
     async def _wait_until_alive(self):
-        logger.info('Reconnecting session')
+        logger.debug('Getting container logs...')
         container = self.docker_client.containers.get(self.container_name)
         # get logs
         _logs = container.logs(tail=10).decode('utf-8').split('\n')

+ 1 - 1
opendevin/server/session/agent.py

@@ -135,4 +135,4 @@ class AgentSession:
             )
             logger.info(f'Restored agent state from session, sid: {self.sid}')
         except Exception as e:
-            print('Error restoring state', e)
+            logger.info(f'Error restoring state: {e}')

+ 2 - 1
tests/unit/test_arg_parser.py

@@ -14,7 +14,7 @@ usage: pytest [-h] [-d DIRECTORY] [-t TASK] [-f FILE] [-c AGENT_CLS]
               [--eval-output-dir EVAL_OUTPUT_DIR]
               [--eval-n-limit EVAL_N_LIMIT]
               [--eval-num-workers EVAL_NUM_WORKERS] [--eval-note EVAL_NOTE]
-              [-l LLM_CONFIG]
+              [-l LLM_CONFIG] [-n NAME]
 
 Run an agent with a specific task
 
@@ -44,6 +44,7 @@ options:
                         Replace default LLM ([llm] section in config.toml)
                         config with the specified LLM config, e.g. "llama3"
                         for [llm.llama3] section in config.toml
+  -n NAME, --name NAME  Name for the session
 """
 
     actual_lines = captured.out.strip().split('\n')