فهرست منبع

Refactor agent to accept agent config (#3430)

* refactor agents to receive their agent config

* add unit test

* fix test

* fix tests
Engel Nyst 1 سال پیش
والد
کامیت
92b1a2da5c

+ 3 - 1
agenthub/browsing_agent/browsing_agent.py

@@ -6,6 +6,7 @@ from browsergym.utils.obs import flatten_axtree_to_str
 from agenthub.browsing_agent.response_parser import BrowsingResponseParser
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.message import Message, TextContent
 from opendevin.events.action import (
@@ -99,13 +100,14 @@ class BrowsingAgent(Agent):
     def __init__(
         self,
         llm: LLM,
+        config: AgentConfig,
     ) -> None:
         """Initializes a new instance of the BrowsingAgent class.
 
         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm)
+        super().__init__(llm, config)
         # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML.
         # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details
         action_subsets = ['chat', 'bid']

+ 3 - 1
agenthub/codeact_agent/codeact_agent.py

@@ -8,6 +8,7 @@ from agenthub.codeact_agent.prompt import (
 )
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.core.message import ImageContent, Message, TextContent
 from opendevin.events.action import (
     Action,
@@ -103,13 +104,14 @@ class CodeActAgent(Agent):
     def __init__(
         self,
         llm: LLM,
+        config: AgentConfig,
     ) -> None:
         """Initializes a new instance of the CodeActAgent class.
 
         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm)
+        super().__init__(llm, config)
         self.reset()
 
     def action_to_str(self, action: Action) -> str:

+ 3 - 1
agenthub/codeact_swe_agent/codeact_swe_agent.py

@@ -7,6 +7,7 @@ from agenthub.codeact_swe_agent.prompt import (
 from agenthub.codeact_swe_agent.response_parser import CodeActSWEResponseParser
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.core.message import ImageContent, Message, TextContent
 from opendevin.events.action import (
     Action,
@@ -66,13 +67,14 @@ class CodeActSWEAgent(Agent):
     def __init__(
         self,
         llm: LLM,
+        config: AgentConfig,
     ) -> None:
         """Initializes a new instance of the CodeActSWEAgent class.
 
         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm)
+        super().__init__(llm, config)
         self.reset()
 
     def action_to_str(self, action: Action) -> str:

+ 3 - 2
agenthub/delegator_agent/agent.py

@@ -1,5 +1,6 @@
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.events.action import Action, AgentDelegateAction, AgentFinishAction
 from opendevin.events.observation import AgentDelegateObservation
 from opendevin.llm.llm import LLM
@@ -13,13 +14,13 @@ class DelegatorAgent(Agent):
 
     current_delegate: str = ''
 
-    def __init__(self, llm: LLM):
+    def __init__(self, llm: LLM, config: AgentConfig):
         """Initialize the Delegator Agent with an LLM
 
         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm)
+        super().__init__(llm, config)
 
     def step(self, state: State) -> Action:
         """Checks to see if current step is completed, returns AgentFinishAction if True.

+ 3 - 2
agenthub/dummy_agent/agent.py

@@ -2,6 +2,7 @@ from typing import TypedDict, Union
 
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.core.schema import AgentState
 from opendevin.events.action import (
     Action,
@@ -45,8 +46,8 @@ class DummyAgent(Agent):
     without making any LLM calls.
     """
 
-    def __init__(self, llm: LLM):
-        super().__init__(llm)
+    def __init__(self, llm: LLM, config: AgentConfig):
+        super().__init__(llm, config)
         self.steps: list[ActionObs] = [
             {
                 'action': AddTaskAction(

+ 3 - 2
agenthub/micro/agent.py

@@ -2,6 +2,7 @@ from jinja2 import BaseLoader, Environment
 
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.core.message import ImageContent, Message, TextContent
 from opendevin.core.utils import json
 from opendevin.events.action import Action
@@ -54,8 +55,8 @@ class MicroAgent(Agent):
 
         return json.dumps(processed_history, **kwargs)
 
-    def __init__(self, llm: LLM):
-        super().__init__(llm)
+    def __init__(self, llm: LLM, config: AgentConfig):
+        super().__init__(llm, config)
         if 'name' not in self.agent_definition:
             raise ValueError('Agent definition must contain a name')
         self.prompt_template = Environment(loader=BaseLoader).from_string(self.prompt)

+ 3 - 2
agenthub/planner_agent/agent.py

@@ -1,6 +1,7 @@
 from agenthub.planner_agent.response_parser import PlannerResponseParser
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.core.message import ImageContent, Message, TextContent
 from opendevin.events.action import Action, AgentFinishAction
 from opendevin.llm.llm import LLM
@@ -18,13 +19,13 @@ class PlannerAgent(Agent):
     runtime_tools: list[RuntimeTool] = [RuntimeTool.BROWSER]
     response_parser = PlannerResponseParser()
 
-    def __init__(self, llm: LLM):
+    def __init__(self, llm: LLM, config: AgentConfig):
         """Initialize the Planner Agent with an LLM
 
         Parameters:
         - llm (LLM): The llm to be used by this agent
         """
-        super().__init__(llm)
+        super().__init__(llm, config)
 
     def step(self, state: State) -> Action:
         """Checks to see if current step is completed, returns AgentFinishAction if True.

+ 3 - 0
opendevin/controller/agent.py

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Type
 
 if TYPE_CHECKING:
     from opendevin.controller.state.state import State
+    from opendevin.core.config import AgentConfig
     from opendevin.events.action import Action
 from opendevin.core.exceptions import (
     AgentAlreadyRegisteredError,
@@ -29,8 +30,10 @@ class Agent(ABC):
     def __init__(
         self,
         llm: LLM,
+        config: 'AgentConfig',
     ):
         self.llm = llm
+        self.config = config
         self._complete = False
 
     @property

+ 10 - 2
opendevin/controller/agent_controller.py

@@ -5,7 +5,7 @@ from typing import Type
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State, TrafficControlState
 from opendevin.controller.stuck import StuckDetector
-from opendevin.core.config import LLMConfig
+from opendevin.core.config import AgentConfig, LLMConfig
 from opendevin.core.exceptions import (
     LLMMalformedActionError,
     LLMNoActionError,
@@ -52,6 +52,7 @@ class AgentController:
     state: State
     confirmation_mode: bool
     agent_to_llm_config: dict[str, LLMConfig]
+    agent_configs: dict[str, AgentConfig]
     agent_task: asyncio.Task | None = None
     parent: 'AgentController | None' = None
     delegate: 'AgentController | None' = None
@@ -64,6 +65,7 @@ class AgentController:
         max_iterations: int,
         max_budget_per_task: float | None = None,
         agent_to_llm_config: dict[str, LLMConfig] | None = None,
+        agent_configs: dict[str, AgentConfig] | None = None,
         sid: str = 'default',
         confirmation_mode: bool = False,
         initial_state: State | None = None,
@@ -79,6 +81,8 @@ class AgentController:
             max_budget_per_task: The maximum budget (in USD) allowed per task, beyond which the agent will stop.
             agent_to_llm_config: A dictionary mapping agent names to LLM configurations in the case that
                 we delegate to a different agent.
+            agent_configs: A dictionary mapping agent names to agent configurations in the case that
+                we delegate to a different agent.
             sid: The session ID of the agent.
             initial_state: The initial state of the controller.
             is_delegate: Whether this controller is a delegate.
@@ -103,6 +107,7 @@ class AgentController:
         )
         self.max_budget_per_task = max_budget_per_task
         self.agent_to_llm_config = agent_to_llm_config if agent_to_llm_config else {}
+        self.agent_configs = agent_configs if agent_configs else {}
 
         # stuck helper
         self._stuck_detector = StuckDetector(self.state)
@@ -256,9 +261,10 @@ class AgentController:
 
     async def start_delegate(self, action: AgentDelegateAction):
         agent_cls: Type[Agent] = Agent.get_cls(action.agent)
+        agent_config = self.agent_configs.get(action.agent, self.agent.config)
         llm_config = self.agent_to_llm_config.get(action.agent, self.agent.llm.config)
         llm = LLM(config=llm_config)
-        delegate_agent = agent_cls(llm=llm)
+        delegate_agent = agent_cls(llm=llm, config=agent_config)
         state = State(
             inputs=action.inputs or {},
             local_iteration=0,
@@ -278,8 +284,10 @@ class AgentController:
             max_iterations=self.state.max_iterations,
             max_budget_per_task=self.max_budget_per_task,
             agent_to_llm_config=self.agent_to_llm_config,
+            agent_configs=self.agent_configs,
             initial_state=state,
             is_delegate=True,
+            headless_mode=self.headless_mode,
         )
         await self.delegate.set_agent_state_to(AgentState.RUNNING)
 

+ 3 - 0
opendevin/core/config.py

@@ -330,6 +330,9 @@ class AppConfig(metaclass=Singleton):
         llm_config_name = agent_config.llm_config
         return self.get_llm_config(llm_config_name)
 
+    def get_agent_configs(self) -> dict[str, AgentConfig]:
+        return self.agents
+
     def __post_init__(self):
         """Post-initialization hook, called when the instance is created with only default values."""
         AppConfig.defaults_dict = self.defaults_to_dict()

+ 4 - 1
opendevin/core/main.py

@@ -93,8 +93,11 @@ async def run_controller(
     # Create the agent
     if agent is None:
         agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
+        agent_config = config.get_agent_config(config.default_agent)
+        llm_config = config.get_llm_config_from_agent(config.default_agent)
         agent = agent_cls(
-            llm=LLM(config=config.get_llm_config_from_agent(config.default_agent))
+            llm=LLM(config=llm_config),
+            config=agent_config,
         )
 
     if runtime is None:

+ 6 - 1
opendevin/server/session/agent.py

@@ -1,7 +1,7 @@
 from opendevin.controller import AgentController
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import AppConfig, LLMConfig
+from opendevin.core.config import AgentConfig, AppConfig, LLMConfig
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.events.stream import EventStream
 from opendevin.runtime import get_runtime_cls
@@ -39,6 +39,7 @@ class AgentSession:
         max_iterations: int,
         max_budget_per_task: float | None = None,
         agent_to_llm_config: dict[str, LLMConfig] | None = None,
+        agent_configs: dict[str, AgentConfig] | None = None,
     ):
         """Starts the agent session.
 
@@ -57,6 +58,7 @@ class AgentSession:
             max_iterations,
             max_budget_per_task=max_budget_per_task,
             agent_to_llm_config=agent_to_llm_config,
+            agent_configs=agent_configs,
         )
 
     async def close(self):
@@ -102,6 +104,7 @@ class AgentSession:
         max_iterations: int,
         max_budget_per_task: float | None = None,
         agent_to_llm_config: dict[str, LLMConfig] | None = None,
+        agent_configs: dict[str, AgentConfig] | None = None,
     ):
         """Creates an AgentController instance."""
         if self.controller is not None:
@@ -109,6 +112,7 @@ class AgentSession:
         if self.runtime is None:
             raise Exception('Runtime must be initialized before the agent controller')
 
+        logger.info(f'Agents: {agent_configs}')
         logger.info(f'Creating agent {agent.name} using LLM {agent.llm.config.model}')
 
         self.controller = AgentController(
@@ -118,6 +122,7 @@ class AgentSession:
             max_iterations=int(max_iterations),
             max_budget_per_task=max_budget_per_task,
             agent_to_llm_config=agent_to_llm_config,
+            agent_configs=agent_configs,
             confirmation_mode=confirmation_mode,
             # AgentSession is designed to communicate with the frontend, so we don't want to
             # run the agent in headless mode.

+ 3 - 1
opendevin/server/session/session.py

@@ -102,7 +102,8 @@ class Session:
         # TODO: override other LLM config & agent config groups (#2075)
 
         llm = LLM(config=self.config.get_llm_config_from_agent(agent_cls))
-        agent = Agent.get_cls(agent_cls)(llm)
+        agent_config = self.config.get_agent_config(agent_cls)
+        agent = Agent.get_cls(agent_cls)(llm, agent_config)
 
         # Create the agent session
         try:
@@ -113,6 +114,7 @@ class Session:
                 max_iterations=max_iterations,
                 max_budget_per_task=self.config.max_budget_per_task,
                 agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
+                agent_configs=self.config.get_agent_configs(),
             )
         except Exception as e:
             logger.exception(f'Error creating controller: {e}')

+ 2 - 2
tests/unit/test_codeact_agent.py

@@ -3,7 +3,7 @@ from unittest.mock import Mock
 import pytest
 
 from agenthub.codeact_agent.codeact_agent import CodeActAgent
-from opendevin.core.config import LLMConfig
+from opendevin.core.config import AgentConfig, LLMConfig
 from opendevin.core.message import TextContent
 from opendevin.events.observation.commands import (
     CmdOutputObservation,
@@ -16,7 +16,7 @@ from opendevin.llm.llm import LLM
 
 @pytest.fixture
 def agent() -> CodeActAgent:
-    agent = CodeActAgent(llm=LLM(LLMConfig()))
+    agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
     agent.llm = Mock()
     agent.llm.config = Mock()
     agent.llm.config.max_message_chars = 100

+ 24 - 0
tests/unit/test_config.py

@@ -538,3 +538,27 @@ embedding_model="openai"
     llm_config = get_llm_config_arg('gpt3', temp_toml_file)
     assert llm_config.model == 'gpt-3.5-turbo'
     assert llm_config.embedding_model == 'openai'
+
+
+def test_get_agent_configs(default_config, temp_toml_file):
+    temp_toml = """
+[core]
+max_iterations = 100
+max_budget_per_task = 4.0
+
+[agent.CodeActAgent]
+memory_enabled = true
+
+[agent.PlannerAgent]
+memory_max_threads = 10
+"""
+
+    with open(temp_toml_file, 'w') as f:
+        f.write(temp_toml)
+
+    load_from_toml(default_config, temp_toml_file)
+
+    codeact_config = default_config.get_agent_configs().get('CodeActAgent')
+    assert codeact_config.memory_enabled is True
+    planner_config = default_config.get_agent_configs().get('PlannerAgent')
+    assert planner_config.memory_max_threads == 10

+ 17 - 4
tests/unit/test_micro_agents.py

@@ -9,6 +9,7 @@ from pytest import TempPathFactory
 from agenthub.micro.registry import all_microagents
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
+from opendevin.core.config import AgentConfig
 from opendevin.events import EventSource
 from opendevin.events.action import MessageAction
 from opendevin.events.stream import EventStream
@@ -31,6 +32,14 @@ def event_stream(temp_dir):
     event_stream.clear()
 
 
+@pytest.fixture
+def agent_configs():
+    return {
+        'CoderAgent': AgentConfig(memory_enabled=True),
+        'PlannerAgent': AgentConfig(memory_enabled=True),
+    }
+
+
 def test_all_agents_are_loaded():
     assert all_microagents is not None
     assert len(all_microagents) > 1
@@ -48,13 +57,15 @@ def test_all_agents_are_loaded():
     assert agent_names == set(all_microagents.keys())
 
 
-def test_coder_agent_with_summary(event_stream: EventStream):
+def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict):
     """Coder agent should render code summary as part of prompt"""
     mock_llm = MagicMock()
     content = json.dumps({'action': 'finish', 'args': {}})
     mock_llm.completion.return_value = {'choices': [{'message': {'content': content}}]}
 
-    coder_agent = Agent.get_cls('CoderAgent')(llm=mock_llm)
+    coder_agent = Agent.get_cls('CoderAgent')(
+        llm=mock_llm, config=agent_configs['CoderAgent']
+    )
     assert coder_agent is not None
 
     task = 'This is a dummy task'
@@ -74,7 +85,7 @@ def test_coder_agent_with_summary(event_stream: EventStream):
     assert summary in prompt
 
 
-def test_coder_agent_without_summary(event_stream: EventStream):
+def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: dict):
     """When there's no codebase_summary available, there shouldn't be any prompt
     about 'code summary'
     """
@@ -82,7 +93,9 @@ def test_coder_agent_without_summary(event_stream: EventStream):
     content = json.dumps({'action': 'finish', 'args': {}})
     mock_llm.completion.return_value = {'choices': [{'message': {'content': content}}]}
 
-    coder_agent = Agent.get_cls('CoderAgent')(llm=mock_llm)
+    coder_agent = Agent.get_cls('CoderAgent')(
+        llm=mock_llm, config=agent_configs['CoderAgent']
+    )
     assert coder_agent is not None
 
     task = 'This is a dummy task'