Refactor LLM config (#2953)

* Add max_message_chars to LLM

* Refactor LLM config

* Fix tests

* Made some functions class methods

* Fix regression

* Fixed comments
Graham Neubig 1 year ago
commit c897791024
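
In short, this PR replaces the LLM class's long list of constructor keywords with a single LLMConfig object: callers build LLM(config=...), and agents read per-agent settings such as max_message_chars from self.llm.config instead of reaching into the global config module. A minimal sketch of the new call pattern, using only names that appear in this diff (LLMConfig's defaults are an assumption here):

    from opendevin.core.config import LLMConfig, config
    from opendevin.llm.llm import LLM

    # Agent-specific lookup, as used in agent_controller.py and server/session/agent.py below
    llm = LLM(config=config.get_llm_config_from_agent('CodeActAgent'))

    # Settings now live on the instance's (deep-copied) config object
    print(llm.config.model, llm.config.max_message_chars)

    # Tests construct a bare default config (see tests/integration/test_agent.py)
    test_llm = LLM(LLMConfig())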

+ 62 - 62
agenthub/codeact_agent/codeact_agent.py

@@ -8,7 +8,6 @@ from agenthub.codeact_agent.prompt import (
 )
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.events.action import (
     Action,
     AgentDelegateAction,
@@ -22,6 +21,7 @@ from opendevin.events.observation import (
     CmdOutputObservation,
     IPythonRunCellObservation,
 )
+from opendevin.events.observation.observation import Observation
 from opendevin.events.serialization.event import truncate_content
 from opendevin.llm.llm import LLM
 from opendevin.runtime.plugins import (
@@ -34,62 +34,6 @@ from opendevin.runtime.tools import RuntimeTool
 ENABLE_GITHUB = True


-def action_to_str(action: Action) -> str:
-    if isinstance(action, CmdRunAction):
-        return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
-    elif isinstance(action, IPythonRunCellAction):
-        return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
-    elif isinstance(action, AgentDelegateAction):
-        return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
-    elif isinstance(action, MessageAction):
-        return action.content
-    return ''
-
-
-def get_action_message(action: Action) -> dict[str, str] | None:
-    if (
-        isinstance(action, AgentDelegateAction)
-        or isinstance(action, CmdRunAction)
-        or isinstance(action, IPythonRunCellAction)
-        or isinstance(action, MessageAction)
-    ):
-        return {
-            'role': 'user' if action.source == 'user' else 'assistant',
-            'content': action_to_str(action),
-        }
-    return None
-
-
-def get_observation_message(obs) -> dict[str, str] | None:
-    max_message_chars = config.get_llm_config_from_agent(
-        'CodeActAgent'
-    ).max_message_chars
-    if isinstance(obs, CmdOutputObservation):
-        content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
-        content += (
-            f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
-        )
-        return {'role': 'user', 'content': content}
-    elif isinstance(obs, IPythonRunCellObservation):
-        content = 'OBSERVATION:\n' + obs.content
-        # replace base64 images with a placeholder
-        splitted = content.split('\n')
-        for i, line in enumerate(splitted):
-            if '![image](data:image/png;base64,' in line:
-                splitted[i] = (
-                    '![image](data:image/png;base64, ...) already displayed to user'
-                )
-        content = '\n'.join(splitted)
-        content = truncate_content(content, max_message_chars)
-        return {'role': 'user', 'content': content}
-    elif isinstance(obs, AgentDelegateObservation):
-        content = 'OBSERVATION:\n' + truncate_content(
-            str(obs.outputs), max_message_chars
-        )
-        return {'role': 'user', 'content': content}
-    return None
-
-
 # FIXME: We can tweak these two settings to create MicroAgents specialized toward different area
 def get_system_message() -> str:
     if ENABLE_GITHUB:
@@ -166,6 +110,61 @@ class CodeActAgent(Agent):
         super().__init__(llm)
         self.reset()

+    def action_to_str(self, action: Action) -> str:
+        if isinstance(action, CmdRunAction):
+            return (
+                f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
+            )
+        elif isinstance(action, IPythonRunCellAction):
+            return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
+        elif isinstance(action, AgentDelegateAction):
+            return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
+        elif isinstance(action, MessageAction):
+            return action.content
+        return ''
+
+    def get_action_message(self, action: Action) -> dict[str, str] | None:
+        if (
+            isinstance(action, AgentDelegateAction)
+            or isinstance(action, CmdRunAction)
+            or isinstance(action, IPythonRunCellAction)
+            or isinstance(action, MessageAction)
+        ):
+            return {
+                'role': 'user' if action.source == 'user' else 'assistant',
+                'content': self.action_to_str(action),
+            }
+        return None
+
+    def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
+        max_message_chars = self.llm.config.max_message_chars
+        if isinstance(obs, CmdOutputObservation):
+            content = 'OBSERVATION:\n' + truncate_content(
+                obs.content, max_message_chars
+            )
+            content += (
+                f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
+            )
+            return {'role': 'user', 'content': content}
+        elif isinstance(obs, IPythonRunCellObservation):
+            content = 'OBSERVATION:\n' + obs.content
+            # replace base64 images with a placeholder
+            splitted = content.split('\n')
+            for i, line in enumerate(splitted):
+                if '![image](data:image/png;base64,' in line:
+                    splitted[i] = (
+                        '![image](data:image/png;base64, ...) already displayed to user'
+                    )
+            content = '\n'.join(splitted)
+            content = truncate_content(content, max_message_chars)
+            return {'role': 'user', 'content': content}
+        elif isinstance(obs, AgentDelegateObservation):
+            content = 'OBSERVATION:\n' + truncate_content(
+                str(obs.outputs), max_message_chars
+            )
+            return {'role': 'user', 'content': content}
+        return None
+
     def reset(self) -> None:
         """Resets the CodeAct Agent."""
         super().reset()
@@ -211,11 +210,12 @@ class CodeActAgent(Agent):

         for event in state.history.get_events():
             # create a regular message from an event
-            message = (
-                get_action_message(event)
-                if isinstance(event, Action)
-                else get_observation_message(event)
-            )
+            if isinstance(event, Action):
+                message = self.get_action_message(event)
+            elif isinstance(event, Observation):
+                message = self.get_observation_message(event)
+            else:
+                raise ValueError(f'Unknown event type: {type(event)}')

             # add regular message
             if message:

+ 54 - 54
agenthub/codeact_swe_agent/codeact_swe_agent.py

@@ -7,7 +7,6 @@ from agenthub.codeact_swe_agent.prompt import (
 from agenthub.codeact_swe_agent.response_parser import CodeActSWEResponseParser
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.events.action import (
     Action,
     AgentFinishAction,
@@ -19,6 +18,7 @@ from opendevin.events.observation import (
     CmdOutputObservation,
     IPythonRunCellObservation,
 )
+from opendevin.events.observation.observation import Observation
 from opendevin.events.serialization.event import truncate_content
 from opendevin.llm.llm import LLM
 from opendevin.runtime.plugins import (
@@ -29,54 +29,6 @@ from opendevin.runtime.plugins import (
 from opendevin.runtime.tools import RuntimeTool


-def action_to_str(action: Action) -> str:
-    if isinstance(action, CmdRunAction):
-        return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
-    elif isinstance(action, IPythonRunCellAction):
-        return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
-    elif isinstance(action, MessageAction):
-        return action.content
-    return ''
-
-
-def get_action_message(action: Action) -> dict[str, str] | None:
-    if (
-        isinstance(action, CmdRunAction)
-        or isinstance(action, IPythonRunCellAction)
-        or isinstance(action, MessageAction)
-    ):
-        return {
-            'role': 'user' if action.source == 'user' else 'assistant',
-            'content': action_to_str(action),
-        }
-    return None
-
-
-def get_observation_message(obs) -> dict[str, str] | None:
-    max_message_chars = config.get_llm_config_from_agent(
-        'CodeActSWEAgent'
-    ).max_message_chars
-    if isinstance(obs, CmdOutputObservation):
-        content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
-        content += (
-            f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
-        )
-        return {'role': 'user', 'content': content}
-    elif isinstance(obs, IPythonRunCellObservation):
-        content = 'OBSERVATION:\n' + obs.content
-        # replace base64 images with a placeholder
-        splitted = content.split('\n')
-        for i, line in enumerate(splitted):
-            if '![image](data:image/png;base64,' in line:
-                splitted[i] = (
-                    '![image](data:image/png;base64, ...) already displayed to user'
-                )
-        content = '\n'.join(splitted)
-        content = truncate_content(content, max_message_chars)
-        return {'role': 'user', 'content': content}
-    return None
-
-
 def get_system_message() -> str:
     return f'{SYSTEM_PREFIX}\n\n{COMMAND_DOCS}\n\n{SYSTEM_SUFFIX}'

@@ -121,6 +73,53 @@ class CodeActSWEAgent(Agent):
         super().__init__(llm)
         self.reset()

+    def action_to_str(self, action: Action) -> str:
+        if isinstance(action, CmdRunAction):
+            return (
+                f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
+            )
+        elif isinstance(action, IPythonRunCellAction):
+            return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
+        elif isinstance(action, MessageAction):
+            return action.content
+        return ''
+
+    def get_action_message(self, action: Action) -> dict[str, str] | None:
+        if (
+            isinstance(action, CmdRunAction)
+            or isinstance(action, IPythonRunCellAction)
+            or isinstance(action, MessageAction)
+        ):
+            return {
+                'role': 'user' if action.source == 'user' else 'assistant',
+                'content': self.action_to_str(action),
+            }
+        return None
+
+    def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
+        max_message_chars = self.llm.config.max_message_chars
+        if isinstance(obs, CmdOutputObservation):
+            content = 'OBSERVATION:\n' + truncate_content(
+                obs.content, max_message_chars
+            )
+            content += (
+                f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
+            )
+            return {'role': 'user', 'content': content}
+        elif isinstance(obs, IPythonRunCellObservation):
+            content = 'OBSERVATION:\n' + obs.content
+            # replace base64 images with a placeholder
+            splitted = content.split('\n')
+            for i, line in enumerate(splitted):
+                if '![image](data:image/png;base64,' in line:
+                    splitted[i] = (
+                        '![image](data:image/png;base64, ...) already displayed to user'
+                    )
+            content = '\n'.join(splitted)
+            content = truncate_content(content, max_message_chars)
+            return {'role': 'user', 'content': content}
+        return None
+
     def reset(self) -> None:
         """Resets the CodeAct Agent."""
         super().reset()
@@ -165,11 +164,12 @@ class CodeActSWEAgent(Agent):

         for event in state.history.get_events():
             # create a regular message from an event
-            message = (
-                get_action_message(event)
-                if isinstance(event, Action)
-                else get_observation_message(event)
-            )
+            if isinstance(event, Action):
+                message = self.get_action_message(event)
+            elif isinstance(event, Observation):
+                message = self.get_observation_message(event)
+            else:
+                raise ValueError(f'Unknown event type: {type(event)}')

             # add regular message
             if message:

+ 23 - 23
agenthub/micro/agent.py

@@ -2,7 +2,6 @@ from jinja2 import BaseLoader, Environment

 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.core.utils import json
 from opendevin.events.action import Action
 from opendevin.events.serialization.action import action_from_dict
@@ -27,32 +26,33 @@ def to_json(obj, **kwargs):
     return json.dumps(obj, **kwargs)


-def history_to_json(history: ShortTermHistory, max_events=20, **kwargs):
-    """Serialize and simplify history to str format"""
-    # TODO: get agent specific llm config
-    llm_config = config.get_llm_config()
-    max_message_chars = llm_config.max_message_chars
-
-    processed_history = []
-    event_count = 0
-
-    for event in history.get_events(reverse=True):
-        if event_count >= max_events:
-            break
-        processed_history.append(event_to_memory(event, max_message_chars))
-        event_count += 1
-
-    # history is in reverse order, let's fix it
-    processed_history.reverse()
-
-    return json.dumps(processed_history, **kwargs)
-
-
 class MicroAgent(Agent):
     VERSION = '1.0'
     prompt = ''
     agent_definition: dict = {}

+    def history_to_json(
+        self, history: ShortTermHistory, max_events: int = 20, **kwargs
+    ):
+        """
+        Serialize and simplify history to str format
+        """
+        processed_history = []
+        event_count = 0
+
+        for event in history.get_events(reverse=True):
+            if event_count >= max_events:
+                break
+            processed_history.append(
+                event_to_memory(event, self.llm.config.max_message_chars)
+            )
+            event_count += 1
+
+        # history is in reverse order, let's fix it
+        processed_history.reverse()
+
+        return json.dumps(processed_history, **kwargs)
+
     def __init__(self, llm: LLM):
         super().__init__(llm)
         if 'name' not in self.agent_definition:
@@ -66,7 +66,7 @@ class MicroAgent(Agent):
             state=state,
             instructions=instructions,
             to_json=to_json,
-            history_to_json=history_to_json,
+            history_to_json=self.history_to_json,
             delegates=self.delegates,
             latest_user_message=state.get_current_user_intent(),
         )

+ 12 - 12
agenthub/monologue_agent/agent.py

@@ -83,10 +83,7 @@ class MonologueAgent(Agent):
         self._add_initial_thoughts(task)
         self._initialized = True

-    def _add_initial_thoughts(self, task):
-        max_message_chars = config.get_llm_config_from_agent(
-            'MonologueAgent'
-        ).max_message_chars
+    def _add_initial_thoughts(self, task: str):
         previous_action = ''
         for thought in INITIAL_THOUGHTS:
             thought = thought.replace('$TASK', task)
@@ -103,7 +100,7 @@ class MonologueAgent(Agent):
                         content=thought, url='', screenshot=''
                     )
                 self.initial_thoughts.append(
-                    event_to_memory(observation, max_message_chars)
+                    event_to_memory(observation, self.llm.config.max_message_chars)
                 )
                 previous_action = ''
             else:
@@ -127,7 +124,9 @@ class MonologueAgent(Agent):
                     previous_action = ActionType.BROWSE
                 else:
                     action = MessageAction(thought)
-                self.initial_thoughts.append(event_to_memory(action, max_message_chars))
+                self.initial_thoughts.append(
+                    event_to_memory(action, self.llm.config.max_message_chars)
+                )

     def step(self, state: State) -> Action:
         """Modifies the current state by adding the most recent actions and observations, then prompts the model to think about it's next action to take using monologue, memory, and hint.
@@ -138,9 +137,6 @@ class MonologueAgent(Agent):
         Returns:
         - Action: The next action to take based on LLM response
         """
-        max_message_chars = config.get_llm_config_from_agent(
-            'MonologueAgent'
-        ).max_message_chars
         goal = state.get_current_user_intent()
         self._initialize(goal)

@@ -148,7 +144,9 @@ class MonologueAgent(Agent):

         # add the events from state.history
         for event in state.history.get_events():
-            recent_events.append(event_to_memory(event, max_message_chars))
+            recent_events.append(
+                event_to_memory(event, self.llm.config.max_message_chars)
+            )

         # add the last messages to long term memory
         if self.memory is not None:
@@ -158,10 +156,12 @@ class MonologueAgent(Agent):
             # this should still work
             # we will need to do this differently: find out if there really is an action or an observation in this step
             if last_action:
-                self.memory.add_event(event_to_memory(last_action, max_message_chars))
+                self.memory.add_event(
+                    event_to_memory(last_action, self.llm.config.max_message_chars)
+                )
             if last_observation:
                 self.memory.add_event(
-                    event_to_memory(last_observation, max_message_chars)
+                    event_to_memory(last_observation, self.llm.config.max_message_chars)
                 )

         # the action prompt with initial thoughts and recent events

+ 1 - 1
agenthub/planner_agent/agent.py

@@ -42,7 +42,7 @@ class PlannerAgent(Agent):
             'abandoned',
         ]:
             return AgentFinishAction()
-        prompt = get_prompt(state)
+        prompt = get_prompt(state, self.llm.config.max_message_chars)
         messages = [{'content': prompt, 'role': 'user'}]
         resp = self.llm.completion(messages=messages)
         return self.response_parser.parse(resp)

+ 2 - 6
agenthub/planner_agent/prompt.py

@@ -1,5 +1,4 @@
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.schema import ActionType
 from opendevin.core.utils import json
@@ -116,8 +115,9 @@ def get_hint(latest_action_id: str) -> str:
     return hints.get(latest_action_id, '')


-def get_prompt(state: State) -> str:
+def get_prompt(state: State, max_message_chars: int) -> str:
     """Gets the prompt for the planner agent.
+
     Formatted with the most recent action-observation pairs, current task, and hint based on last action

     Parameters:
@@ -126,10 +126,6 @@ def get_prompt(state: State) -> str:
     Returns:
     - str: The formatted string prompt with historical values
     """
-    max_message_chars = config.get_llm_config_from_agent(
-        'PlannerAgent'
-    ).max_message_chars
-
     # the plan
     plan_str = json.dumps(state.root_task.to_dict(), indent=2)


+ 1 - 1
opendevin/controller/agent_controller.py

@@ -248,7 +248,7 @@ class AgentController:
     async def start_delegate(self, action: AgentDelegateAction):
         agent_cls: Type[Agent] = Agent.get_cls(action.agent)
         llm_config = config.get_llm_config_from_agent(action.agent)
-        llm = LLM(llm_config=llm_config)
+        llm = LLM(config=llm_config)
         delegate_agent = agent_cls(llm=llm)
         state = State(
             inputs=action.inputs or {},

+ 1 - 1
opendevin/core/config.py

@@ -198,7 +198,7 @@ class AppConfig(metaclass=Singleton):
         file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
     """

-    llms: dict = field(default_factory=dict)
+    llms: dict[str, LLMConfig] = field(default_factory=dict)
     agents: dict = field(default_factory=dict)
     default_agent: str = 'CodeActAgent'
     sandbox: SandboxConfig = field(default_factory=SandboxConfig)

+ 2 - 2
opendevin/core/main.py

@@ -52,7 +52,7 @@ async def run_agent_controller(
     """
     """
     # Logging
     # Logging
     logger.info(
     logger.info(
-        f'Running agent {agent.name}, model {agent.llm.model_name}, with task: "{task_str}"'
+        f'Running agent {agent.name}, model {agent.llm.config.model}, with task: "{task_str}"'
     )
     )
 
 
     # set up the event stream
     # set up the event stream
@@ -163,7 +163,7 @@ if __name__ == '__main__':
         if llm_config is None:
             raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
         config.set_llm_config(llm_config)
-    llm = LLM(llm_config=config.get_llm_config_from_agent(args.agent_cls))
+    llm = LLM(config=config.get_llm_config_from_agent(args.agent_cls))

     # Create the agent
     AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)

+ 62 - 142
opendevin/llm/llm.py

@@ -1,6 +1,9 @@
+import copy
 import warnings
 from functools import partial

+from opendevin.core.config import LLMConfig
+
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
     import litellm
@@ -21,7 +24,6 @@ from tenacity import (
     wait_random_exponential,
 )

-from opendevin.core.config import config
 from opendevin.core.logger import llm_prompt_logger, llm_response_logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.metrics import Metrics
@@ -35,155 +37,71 @@ class LLM:
     """The LLM class represents a Language Model instance.
     """The LLM class represents a Language Model instance.
 
 
     Attributes:
     Attributes:
-        model_name (str): The name of the language model.
-        api_key (str): The API key for accessing the language model.
-        base_url (str): The base URL for the language model API.
-        api_version (str): The version of the API to use.
-        max_input_tokens (int): The maximum number of tokens to send to the LLM per task.
-        max_output_tokens (int): The maximum number of tokens to receive from the LLM per task.
-        llm_timeout (int): The maximum time to wait for a response in seconds.
-        custom_llm_provider (str): A custom LLM provider.
+        config: an LLMConfig object specifying the configuration of the LLM.
     """
     """
 
 
     def __init__(
     def __init__(
         self,
         self,
-        model=None,
-        api_key=None,
-        base_url=None,
-        api_version=None,
-        num_retries=None,
-        retry_min_wait=None,
-        retry_max_wait=None,
-        llm_timeout=None,
-        llm_temperature=None,
-        llm_top_p=None,
-        custom_llm_provider=None,
-        max_input_tokens=None,
-        max_output_tokens=None,
-        llm_config=None,
-        metrics=None,
-        cost_metric_supported=True,
-        input_cost_per_token=None,
-        output_cost_per_token=None,
+        config: LLMConfig,
+        metrics: Metrics | None = None,
     ):
         """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

         Passing simple parameters always overrides config.

         Args:
-            model (str, optional): The name of the language model. Defaults to LLM_MODEL.
-            api_key (str, optional): The API key for accessing the language model. Defaults to LLM_API_KEY.
-            base_url (str, optional): The base URL for the language model API. Defaults to LLM_BASE_URL. Not necessary for OpenAI.
-            api_version (str, optional): The version of the API to use. Defaults to LLM_API_VERSION. Not necessary for OpenAI.
-            num_retries (int, optional): The number of retries for API calls. Defaults to LLM_NUM_RETRIES.
-            retry_min_wait (int, optional): The minimum time to wait between retries in seconds. Defaults to LLM_RETRY_MIN_TIME.
-            retry_max_wait (int, optional): The maximum time to wait between retries in seconds. Defaults to LLM_RETRY_MAX_TIME.
-            max_input_tokens (int, optional): The maximum number of tokens to send to the LLM per task. Defaults to LLM_MAX_INPUT_TOKENS.
-            max_output_tokens (int, optional): The maximum number of tokens to receive from the LLM per task. Defaults to LLM_MAX_OUTPUT_TOKENS.
-            custom_llm_provider (str, optional): A custom LLM provider. Defaults to LLM_CUSTOM_LLM_PROVIDER.
-            llm_timeout (int, optional): The maximum time to wait for a response in seconds. Defaults to LLM_TIMEOUT.
-            llm_temperature (float, optional): The temperature for LLM sampling. Defaults to LLM_TEMPERATURE.
-            metrics (Metrics, optional): The metrics object to use. Defaults to None.
-            cost_metric_supported (bool, optional): Whether the cost metric is supported. Defaults to True.
-            input_cost_per_token (float, optional): The cost per input token.
-            output_cost_per_token (float, optional): The cost per output token.
+            config: The LLM configuration
         """
         """
-        if llm_config is None:
-            llm_config = config.get_llm_config()
-        model = model if model is not None else llm_config.model
-        api_key = api_key if api_key is not None else llm_config.api_key
-        base_url = base_url if base_url is not None else llm_config.base_url
-        api_version = api_version if api_version is not None else llm_config.api_version
-        num_retries = num_retries if num_retries is not None else llm_config.num_retries
-        retry_min_wait = (
-            retry_min_wait if retry_min_wait is not None else llm_config.retry_min_wait
-        )
-        retry_max_wait = (
-            retry_max_wait if retry_max_wait is not None else llm_config.retry_max_wait
-        )
-        llm_timeout = llm_timeout if llm_timeout is not None else llm_config.timeout
-        llm_temperature = (
-            llm_temperature if llm_temperature is not None else llm_config.temperature
-        )
-        llm_top_p = llm_top_p if llm_top_p is not None else llm_config.top_p
-        custom_llm_provider = (
-            custom_llm_provider
-            if custom_llm_provider is not None
-            else llm_config.custom_llm_provider
-        )
-        max_input_tokens = (
-            max_input_tokens
-            if max_input_tokens is not None
-            else llm_config.max_input_tokens
-        )
-        max_output_tokens = (
-            max_output_tokens
-            if max_output_tokens is not None
-            else llm_config.max_output_tokens
-        )
-        input_cost_per_token = (
-            input_cost_per_token
-            if input_cost_per_token is not None
-            else llm_config.input_cost_per_token
-        )
-        output_cost_per_token = (
-            output_cost_per_token
-            if output_cost_per_token is not None
-            else llm_config.output_cost_per_token
-        )
-        metrics = metrics if metrics is not None else Metrics()
-
-        logger.info(f'Initializing LLM with model: {model}')
-        self.model_name = model
-        self.api_key = api_key
-        self.base_url = base_url
-        self.api_version = api_version
-        self.max_input_tokens = max_input_tokens
-        self.max_output_tokens = max_output_tokens
-        self.input_cost_per_token = input_cost_per_token
-        self.output_cost_per_token = output_cost_per_token
-        self.llm_timeout = llm_timeout
-        self.custom_llm_provider = custom_llm_provider
-        self.metrics = metrics
-        self.cost_metric_supported = cost_metric_supported
+
+        self.config = copy.deepcopy(config)
+        self.metrics = metrics if metrics is not None else Metrics()
+        self.cost_metric_supported = True

         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
-            if not self.model_name.startswith('openrouter'):
-                self.model_info = litellm.get_model_info(self.model_name.split(':')[0])
+            if not config.model.startswith('openrouter'):
+                self.model_info = litellm.get_model_info(config.model.split(':')[0])
             else:
             else:
-                self.model_info = litellm.get_model_info(self.model_name)
+                self.model_info = litellm.get_model_info(config.model)
         # noinspection PyBroadException
         except Exception:
-            logger.warning(f'Could not get model info for {self.model_name}')
-
-        if self.max_input_tokens is None:
-            if self.model_info is not None and 'max_input_tokens' in self.model_info:
-                self.max_input_tokens = self.model_info['max_input_tokens']
+            logger.warning(f'Could not get model info for {config.model}')
+
+        # Set the max tokens in an LM-specific way if not set
+        if config.max_input_tokens is None:
+            if (
+                self.model_info is not None
+                and 'max_input_tokens' in self.model_info
+                and isinstance(self.model_info['max_input_tokens'], int)
+            ):
+                self.config.max_input_tokens = self.model_info['max_input_tokens']
             else:
             else:
                 # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                 # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
-                self.max_input_tokens = 4096
-
-        if self.max_output_tokens is None:
-            if self.model_info is not None and 'max_output_tokens' in self.model_info:
-                self.max_output_tokens = self.model_info['max_output_tokens']
+                self.config.max_input_tokens = 4096
+
+        if config.max_output_tokens is None:
+            if (
+                self.model_info is not None
+                and 'max_output_tokens' in self.model_info
+                and isinstance(self.model_info['max_output_tokens'], int)
+            ):
+                self.config.max_output_tokens = self.model_info['max_output_tokens']
             else:
             else:
-                # Enough tokens for most output actions, and not too many for a bad llm to get carried away responding
-                # with thousands of unwanted tokens
-                self.max_output_tokens = 1024
+                # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
+                self.config.max_output_tokens = 1024

         self._completion = partial(
             litellm_completion,
-            model=self.model_name,
-            api_key=self.api_key,
-            base_url=self.base_url,
-            api_version=self.api_version,
-            custom_llm_provider=custom_llm_provider,
-            max_tokens=self.max_output_tokens,
-            timeout=self.llm_timeout,
-            temperature=llm_temperature,
-            top_p=llm_top_p,
+            model=self.config.model,
+            api_key=self.config.api_key,
+            base_url=self.config.base_url,
+            api_version=self.config.api_version,
+            custom_llm_provider=self.config.custom_llm_provider,
+            max_tokens=self.config.max_output_tokens,
+            timeout=self.config.timeout,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
         )

         completion_unwrapped = self._completion
@@ -197,8 +115,10 @@ class LLM:

         @retry(
             reraise=True,
-            stop=stop_after_attempt(num_retries),
-            wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
+            stop=stop_after_attempt(config.num_retries),
+            wait=wait_random_exponential(
+                min=config.retry_min_wait, max=config.retry_max_wait
+            ),
             retry=retry_if_exception_type(
                 (
                     RateLimitError,
@@ -267,7 +187,7 @@ class LLM:
         Returns:
             int: The number of tokens.
         """
-        return litellm.token_counter(model=self.model_name, messages=messages)
+        return litellm.token_counter(model=self.config.model, messages=messages)

     def is_local(self):
         """Determines if the system is using a locally running LLM.
@@ -275,12 +195,12 @@ class LLM:
         Returns:
             boolean: True if executing a local model.
         """
-        if self.base_url is not None:
+        if self.config.base_url is not None:
             for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
-                if substring in self.base_url:
+                if substring in self.config.base_url:
                     return True
-        elif self.model_name is not None:
-            if self.model_name.startswith('ollama'):
+        elif self.config.model is not None:
+            if self.config.model.startswith('ollama'):
                 return True
         return False

@@ -299,12 +219,12 @@ class LLM:

         extra_kwargs = {}
         if (
-            self.input_cost_per_token is not None
-            and self.output_cost_per_token is not None
+            self.config.input_cost_per_token is not None
+            and self.config.output_cost_per_token is not None
         ):
             cost_per_token = CostPerToken(
-                input_cost_per_token=self.input_cost_per_token,
-                output_cost_per_token=self.output_cost_per_token,
+                input_cost_per_token=self.config.input_cost_per_token,
+                output_cost_per_token=self.config.output_cost_per_token,
             )
             logger.info(f'Using custom cost per token: {cost_per_token}')
             extra_kwargs['custom_cost_per_token'] = cost_per_token
@@ -322,11 +242,11 @@ class LLM:
         return 0.0

     def __str__(self):
-        if self.api_version:
-            return f'LLM(model={self.model_name}, api_version={self.api_version}, base_url={self.base_url})'
-        elif self.base_url:
-            return f'LLM(model={self.model_name}, base_url={self.base_url})'
-        return f'LLM(model={self.model_name})'
+        if self.config.api_version:
+            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
+        elif self.config.base_url:
+            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
+        return f'LLM(model={self.config.model})'

     def __repr__(self):
         return str(self)

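The opendevin/llm/llm.py rewrite above is the core of the change: __init__ now accepts only config: LLMConfig plus an optional Metrics, deep-copies the config so the instance owns its settings, and back-fills max_input_tokens / max_output_tokens from litellm's model info, defaulting to 4096 and 1024 when the model is unknown. A standalone sketch of that fallback logic, mirroring the diff (the helper name is hypothetical, and the openrouter special case is omitted for brevity):

    import copy

    import litellm

    def resolve_token_limits(cfg):
        # Deep-copy first, as LLM.__init__ does, so the caller's config is never mutated
        cfg = copy.deepcopy(cfg)
        try:
            model_info = litellm.get_model_info(cfg.model.split(':')[0])
        except Exception:
            model_info = None  # unknown model: fall through to the defaults below
        if cfg.max_input_tokens is None:
            if model_info is not None and isinstance(model_info['max_input_tokens'], int):
                cfg.max_input_tokens = model_info['max_input_tokens']
            else:
                cfg.max_input_tokens = 4096  # safe gpt-3.5-level fallback
        if cfg.max_output_tokens is None:
            if model_info is not None and isinstance(model_info['max_output_tokens'], int):
                cfg.max_output_tokens = model_info['max_output_tokens']
            else:
                cfg.max_output_tokens = 1024  # enough for most single actions
        return cfg
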
+ 1 - 1
opendevin/server/session/agent.py

@@ -97,7 +97,7 @@ class AgentSession:

         # TODO: override other LLM config & agent config groups (#2075)

-        llm = LLM(llm_config=config.get_llm_config_from_agent(agent_cls))
+        llm = LLM(config=config.get_llm_config_from_agent(agent_cls))
         agent = Agent.get_cls(agent_cls)(llm)
         logger.info(f'Creating agent {agent.name} using LLM {llm}')
         if isinstance(agent, CodeActAgent):

+ 10 - 8
tests/integration/test_agent.py

@@ -7,7 +7,7 @@ import pytest

 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import parse_arguments
+from opendevin.core.config import LLMConfig, parse_arguments
 from opendevin.core.main import run_agent_controller
 from opendevin.core.schema import AgentState
 from opendevin.events.action import (
@@ -44,20 +44,22 @@ print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
     os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
     reason='Manager agent is not capable of finishing this in reasonable steps yet',
 )
-def test_write_simple_script():
+def test_write_simple_script() -> None:
     task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
     args = parse_arguments()

     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))

     final_state: State | None = asyncio.run(
         run_agent_controller(agent, task, exit_on_message=True)
     )
+    assert final_state is not None
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.last_error is None

     # Verify the script file exists
+    assert workspace_base is not None
     script_path = os.path.join(workspace_base, 'hello.sh')
     assert os.path.exists(script_path), 'The file "hello.sh" does not exist'

@@ -103,7 +105,7 @@ def test_edits():
         shutil.copy(os.path.join(source_dir, file), dest_file)

     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))

     # Execute the task
     task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
@@ -137,7 +139,7 @@ def test_ipython():
     args = parse_arguments()

     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))

     # Execute the task
     task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
@@ -171,7 +173,7 @@ def test_simple_task_rejection():
     args = parse_arguments()

     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))

     # Give an impossible task to do: cannot write a commit message because
     # the workspace is not a git repo
@@ -195,7 +197,7 @@ def test_ipython_module():
     args = parse_arguments()

     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))

     # Execute the task
     task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
@@ -235,7 +237,7 @@ def test_browse_internet(http_server):
     args = parse_arguments()

     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))

     # Execute the task
     task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'

+ 4 - 5
tests/unit/test_action_serialization.py

@@ -1,4 +1,3 @@
-from opendevin.core.config import config
 from opendevin.events.action import (
     Action,
     AddTaskAction,
@@ -20,7 +19,9 @@ from opendevin.events.serialization import (
 )


-def serialization_deserialization(original_action_dict, cls):
+def serialization_deserialization(
+    original_action_dict, cls, max_message_chars: int = 10000
+):
     action_instance = event_from_dict(original_action_dict)
     assert isinstance(
         action_instance, Action
@@ -29,9 +30,7 @@ def serialization_deserialization(original_action_dict, cls):
         action_instance, cls
     ), f'The action instance should be an instance of {cls.__name__}.'
     serialized_action_dict = event_to_dict(action_instance)
-    serialized_action_memory = event_to_memory(
-        action_instance, config.get_llm_config().max_message_chars
-    )
+    serialized_action_memory = event_to_memory(action_instance, max_message_chars)
     serialized_action_dict.pop('message')
     assert (
         serialized_action_dict == original_action_dict

+ 4 - 3
tests/unit/test_observation_serialization.py

@@ -1,4 +1,3 @@
-from opendevin.core.config import config
 from opendevin.events.observation import (
     CmdOutputObservation,
     Observation,
@@ -10,7 +9,9 @@ from opendevin.events.serialization import (
 )


-def serialization_deserialization(original_observation_dict, cls):
+def serialization_deserialization(
+    original_observation_dict, cls, max_message_chars: int = 10000
+):
     observation_instance = event_from_dict(original_observation_dict)
     assert isinstance(
         observation_instance, Observation
@@ -20,7 +21,7 @@ def serialization_deserialization(original_observation_dict, cls):
     ), 'The observation instance should be an instance of CmdOutputObservation.'
     serialized_observation_dict = event_to_dict(observation_instance)
     serialized_observation_memory = event_to_memory(
-        observation_instance, config.get_llm_config().max_message_chars
+        observation_instance, max_message_chars
     )
     assert (
         serialized_observation_dict == original_observation_dict