Kaynağa Gözat

Remove global config from session (#2987)

* Remove global config from session

* Fix double agent
Graham Neubig 1 yıl önce
ebeveyn
işleme
692fe21d60

+ 3 - 1
opendevin/server/listen.py

@@ -40,7 +40,9 @@ from opendevin.events.observation import (
 from opendevin.events.serialization import event_to_dict
 from opendevin.llm import bedrock
 from opendevin.server.auth import get_sid_from_token, sign_token
-from opendevin.server.session import session_manager
+from opendevin.server.session import SessionManager
+
+session_manager = SessionManager(config)
 
 app = FastAPI()
 app.add_middleware(

+ 1 - 3
opendevin/server/session/__init__.py

@@ -1,6 +1,4 @@
 from .manager import SessionManager
 from .session import Session
 
-session_manager = SessionManager()
-
-__all__ = ['Session', 'SessionManager', 'session_manager']
+__all__ = ['Session', 'SessionManager']

+ 21 - 40
opendevin/server/session/agent.py

@@ -4,11 +4,9 @@ from agenthub.codeact_agent.codeact_agent import CodeActAgent
 from opendevin.controller import AgentController
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
+from opendevin.core.config import SandboxConfig
 from opendevin.core.logger import opendevin_logger as logger
-from opendevin.core.schema import ConfigType
 from opendevin.events.stream import EventStream
-from opendevin.llm.llm import LLM
 from opendevin.runtime import DockerSSHBox, get_runtime_cls
 from opendevin.runtime.runtime import Runtime
 from opendevin.runtime.server.runtime import ServerRuntime
@@ -32,7 +30,14 @@ class AgentSession:
         self.sid = sid
         self.event_stream = EventStream(sid)
 
-    async def start(self, start_event: dict):
+    async def start(
+        self,
+        runtime_name: str,
+        sandbox_config: SandboxConfig,
+        agent: Agent,
+        confirmation_mode: bool,
+        max_iterations: int,
+    ):
         """Starts the agent session.
 
         Args:
@@ -42,8 +47,8 @@ class AgentSession:
             raise Exception(
                 'Session already started. You need to close this session and start a new one.'
             )
-        await self._create_runtime()
-        await self._create_controller(start_event)
+        await self._create_runtime(runtime_name, sandbox_config)
+        await self._create_controller(agent, confirmation_mode, max_iterations)
 
     async def close(self):
         if self._closed:
@@ -56,52 +61,28 @@ class AgentSession:
             await self.runtime.close()
         self._closed = True
 
-    async def _create_runtime(self):
+    async def _create_runtime(self, runtime_name: str, sandbox_config: SandboxConfig):
+        """Creates a runtime instance."""
         if self.runtime is not None:
             raise Exception('Runtime already created')
 
-        logger.info(f'Using runtime: {config.runtime}')
-        runtime_cls = get_runtime_cls(config.runtime)
+        logger.info(f'Using runtime: {runtime_name}')
+        runtime_cls = get_runtime_cls(runtime_name)
         self.runtime = runtime_cls(
-            sandbox_config=config.sandbox, event_stream=self.event_stream, sid=self.sid
+            sandbox_config=sandbox_config, event_stream=self.event_stream, sid=self.sid
         )
         await self.runtime.ainit()
 
-    async def _create_controller(self, start_event: dict):
-        """Creates an AgentController instance.
-
-        Args:
-            start_event: The start event data.
-        """
+    async def _create_controller(
+        self, agent: Agent, confirmation_mode: bool, max_iterations: int
+    ):
+        """Creates an AgentController instance."""
         if self.controller is not None:
             raise Exception('Controller already created')
         if self.runtime is None:
             raise Exception('Runtime must be initialized before the agent controller')
-        args = {
-            key: value
-            for key, value in start_event.get('args', {}).items()
-            if value != ''
-        }  # remove empty values, prevent FE from sending empty strings
-        agent_cls = args.get(ConfigType.AGENT, config.default_agent)
-        confirmation_mode = args.get(
-            ConfigType.CONFIRMATION_MODE, config.confirmation_mode
-        )
-        max_iterations = args.get(ConfigType.MAX_ITERATIONS, config.max_iterations)
-
-        # override default LLM config
-        default_llm_config = config.get_llm_config()
-        default_llm_config.model = args.get(
-            ConfigType.LLM_MODEL, default_llm_config.model
-        )
-        default_llm_config.api_key = args.get(
-            ConfigType.LLM_API_KEY, default_llm_config.api_key
-        )
-
-        # TODO: override other LLM config & agent config groups (#2075)
 
-        llm = LLM(config=config.get_llm_config_from_agent(agent_cls))
-        agent = Agent.get_cls(agent_cls)(llm)
-        logger.info(f'Creating agent {agent.name} using LLM {llm}')
+        logger.info(f'Creating agent {agent.name} using LLM {agent.llm.config.model}')
         if isinstance(agent, CodeActAgent):
             if not self.runtime or not (
                 isinstance(self.runtime, ServerRuntime)

+ 4 - 2
opendevin/server/session/manager.py

@@ -4,6 +4,7 @@ from typing import Optional
 
 from fastapi import WebSocket
 
+from opendevin.core.config import AppConfig
 from opendevin.core.logger import opendevin_logger as logger
 
 from .session import Session
@@ -14,13 +15,14 @@ class SessionManager:
     cleanup_interval: int = 300
     session_timeout: int = 600
 
-    def __init__(self):
+    def __init__(self, config: AppConfig):
         asyncio.create_task(self._cleanup_sessions())
+        self.config = config
 
     def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
         if sid in self._sessions:
             asyncio.create_task(self._sessions[sid].close())
-        self._sessions[sid] = Session(sid=sid, ws=ws_conn)
+        self._sessions[sid] = Session(sid=sid, ws=ws_conn, config=self.config)
         return self._sessions[sid]
 
     def get_session(self, sid: str) -> Session | None:

+ 37 - 2
opendevin/server/session/session.py

@@ -3,10 +3,13 @@ import time
 
 from fastapi import WebSocket, WebSocketDisconnect
 
+from opendevin.controller.agent import Agent
+from opendevin.core.config import AppConfig
 from opendevin.core.const.guide_url import TROUBLESHOOTING_URL
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.schema import AgentState
 from opendevin.core.schema.action import ActionType
+from opendevin.core.schema.config import ConfigType
 from opendevin.events.action import Action, ChangeAgentStateAction, NullAction
 from opendevin.events.event import Event, EventSource
 from opendevin.events.observation import (
@@ -16,6 +19,7 @@ from opendevin.events.observation import (
 )
 from opendevin.events.serialization import event_from_dict, event_to_dict
 from opendevin.events.stream import EventStreamSubscriber
+from opendevin.llm.llm import LLM
 
 from .agent import AgentSession
 
@@ -29,7 +33,7 @@ class Session:
     is_alive: bool = True
     agent_session: AgentSession
 
-    def __init__(self, sid: str, ws: WebSocket | None):
+    def __init__(self, sid: str, ws: WebSocket | None, config: AppConfig):
         self.sid = sid
         self.websocket = ws
         self.last_active_ts = int(time.time())
@@ -37,6 +41,7 @@ class Session:
         self.agent_session.event_stream.subscribe(
             EventStreamSubscriber.SERVER, self.on_event
         )
+        self.config = config
 
     async def close(self):
         self.is_alive = False
@@ -67,8 +72,38 @@ class Session:
         self.agent_session.event_stream.add_event(
             AgentStateChangedObservation('', AgentState.LOADING), EventSource.AGENT
         )
+        # Extract the agent-relevant arguments from the request
+        args = {
+            key: value for key, value in data.get('args', {}).items() if value != ''
+        }
+        agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
+        confirmation_mode = args.get(
+            ConfigType.CONFIRMATION_MODE, self.config.confirmation_mode
+        )
+        max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
+        # override default LLM config
+        default_llm_config = self.config.get_llm_config()
+        default_llm_config.model = args.get(
+            ConfigType.LLM_MODEL, default_llm_config.model
+        )
+        default_llm_config.api_key = args.get(
+            ConfigType.LLM_API_KEY, default_llm_config.api_key
+        )
+
+        # TODO: override other LLM config & agent config groups (#2075)
+
+        llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
+        agent = Agent.get_cls(agent_cls)(llm)
+
+        # Create the agent session
         try:
-            await self.agent_session.start(data)
+            await self.agent_session.start(
+                runtime_name=self.config.runtime,
+                sandbox_config=self.config.sandbox,
+                agent=agent,
+                confirmation_mode=confirmation_mode,
+                max_iterations=max_iterations,
+            )
         except Exception as e:
             logger.exception(f'Error creating controller: {e}')
             await self.send_error(