Refactor LLM config (#2953)

* Add max_message_chars to LLM

* Refactor LLM config

* Fix tests

* Make some functions methods of the agent classes

* Fix regression

* Fix comments
Graham Neubig, 1 year ago · commit c897791024

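In short, after this change an LLM is built from a single LLMConfig, and everything downstream reads settings off llm.config instead of reaching into the global app config. A minimal sketch of the new call pattern, based on the call sites in this diff (the agent name and the bare LLMConfig() below are illustrative):

    from opendevin.core.config import LLMConfig, config
    from opendevin.llm.llm import LLM

    # Resolve the agent-specific config from the app config, as the controller does...
    llm = LLM(config=config.get_llm_config_from_agent('CodeActAgent'))

    # ...or construct one directly, as the integration tests now do.
    llm = LLM(config=LLMConfig())

    # Agents no longer call config.get_llm_config_from_agent() themselves;
    # they read values from the LLM instance they were handed.
    max_message_chars = llm.config.max_message_chars
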
+ 62 - 62
agenthub/codeact_agent/codeact_agent.py

@@ -8,7 +8,6 @@ from agenthub.codeact_agent.prompt import (
 )
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.events.action import (
     Action,
     AgentDelegateAction,
@@ -22,6 +21,7 @@ from opendevin.events.observation import (
     CmdOutputObservation,
     IPythonRunCellObservation,
 )
+from opendevin.events.observation.observation import Observation
 from opendevin.events.serialization.event import truncate_content
 from opendevin.llm.llm import LLM
 from opendevin.runtime.plugins import (
@@ -34,62 +34,6 @@ from opendevin.runtime.tools import RuntimeTool
 ENABLE_GITHUB = True
 
 
-def action_to_str(action: Action) -> str:
-    if isinstance(action, CmdRunAction):
-        return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
-    elif isinstance(action, IPythonRunCellAction):
-        return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
-    elif isinstance(action, AgentDelegateAction):
-        return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
-    elif isinstance(action, MessageAction):
-        return action.content
-    return ''
-
-
-def get_action_message(action: Action) -> dict[str, str] | None:
-    if (
-        isinstance(action, AgentDelegateAction)
-        or isinstance(action, CmdRunAction)
-        or isinstance(action, IPythonRunCellAction)
-        or isinstance(action, MessageAction)
-    ):
-        return {
-            'role': 'user' if action.source == 'user' else 'assistant',
-            'content': action_to_str(action),
-        }
-    return None
-
-
-def get_observation_message(obs) -> dict[str, str] | None:
-    max_message_chars = config.get_llm_config_from_agent(
-        'CodeActAgent'
-    ).max_message_chars
-    if isinstance(obs, CmdOutputObservation):
-        content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
-        content += (
-            f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
-        )
-        return {'role': 'user', 'content': content}
-    elif isinstance(obs, IPythonRunCellObservation):
-        content = 'OBSERVATION:\n' + obs.content
-        # replace base64 images with a placeholder
-        splitted = content.split('\n')
-        for i, line in enumerate(splitted):
-            if '![image](data:image/png;base64,' in line:
-                splitted[i] = (
-                    '![image](data:image/png;base64, ...) already displayed to user'
-                )
-        content = '\n'.join(splitted)
-        content = truncate_content(content, max_message_chars)
-        return {'role': 'user', 'content': content}
-    elif isinstance(obs, AgentDelegateObservation):
-        content = 'OBSERVATION:\n' + truncate_content(
-            str(obs.outputs), max_message_chars
-        )
-        return {'role': 'user', 'content': content}
-    return None
-
-
 # FIXME: We can tweak these two settings to create MicroAgents specialized toward different area
 def get_system_message() -> str:
     if ENABLE_GITHUB:
@@ -166,6 +110,61 @@ class CodeActAgent(Agent):
         super().__init__(llm)
         self.reset()
 
+    def action_to_str(self, action: Action) -> str:
+        if isinstance(action, CmdRunAction):
+            return (
+                f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
+            )
+        elif isinstance(action, IPythonRunCellAction):
+            return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
+        elif isinstance(action, AgentDelegateAction):
+            return f'{action.thought}\n<execute_browse>\n{action.inputs["task"]}\n</execute_browse>'
+        elif isinstance(action, MessageAction):
+            return action.content
+        return ''
+
+    def get_action_message(self, action: Action) -> dict[str, str] | None:
+        if (
+            isinstance(action, AgentDelegateAction)
+            or isinstance(action, CmdRunAction)
+            or isinstance(action, IPythonRunCellAction)
+            or isinstance(action, MessageAction)
+        ):
+            return {
+                'role': 'user' if action.source == 'user' else 'assistant',
+                'content': self.action_to_str(action),
+            }
+        return None
+
+    def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
+        max_message_chars = self.llm.config.max_message_chars
+        if isinstance(obs, CmdOutputObservation):
+            content = 'OBSERVATION:\n' + truncate_content(
+                obs.content, max_message_chars
+            )
+            content += (
+                f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
+            )
+            return {'role': 'user', 'content': content}
+        elif isinstance(obs, IPythonRunCellObservation):
+            content = 'OBSERVATION:\n' + obs.content
+            # replace base64 images with a placeholder
+            splitted = content.split('\n')
+            for i, line in enumerate(splitted):
+                if '![image](data:image/png;base64,' in line:
+                    splitted[i] = (
+                        '![image](data:image/png;base64, ...) already displayed to user'
+                    )
+            content = '\n'.join(splitted)
+            content = truncate_content(content, max_message_chars)
+            return {'role': 'user', 'content': content}
+        elif isinstance(obs, AgentDelegateObservation):
+            content = 'OBSERVATION:\n' + truncate_content(
+                str(obs.outputs), max_message_chars
+            )
+            return {'role': 'user', 'content': content}
+        return None
+
     def reset(self) -> None:
         """Resets the CodeAct Agent."""
         super().reset()
@@ -211,11 +210,12 @@ class CodeActAgent(Agent):
 
         for event in state.history.get_events():
             # create a regular message from an event
-            message = (
-                get_action_message(event)
-                if isinstance(event, Action)
-                else get_observation_message(event)
-            )
+            if isinstance(event, Action):
+                message = self.get_action_message(event)
+            elif isinstance(event, Observation):
+                message = self.get_observation_message(event)
+            else:
+                raise ValueError(f'Unknown event type: {type(event)}')
 
             # add regular message
             if message:

+ 54 - 54
agenthub/codeact_swe_agent/codeact_swe_agent.py

@@ -7,7 +7,6 @@ from agenthub.codeact_swe_agent.prompt import (
 from agenthub.codeact_swe_agent.response_parser import CodeActSWEResponseParser
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.events.action import (
     Action,
     AgentFinishAction,
@@ -19,6 +18,7 @@ from opendevin.events.observation import (
     CmdOutputObservation,
     IPythonRunCellObservation,
 )
+from opendevin.events.observation.observation import Observation
 from opendevin.events.serialization.event import truncate_content
 from opendevin.llm.llm import LLM
 from opendevin.runtime.plugins import (
@@ -29,54 +29,6 @@ from opendevin.runtime.plugins import (
 from opendevin.runtime.tools import RuntimeTool
 
 
-def action_to_str(action: Action) -> str:
-    if isinstance(action, CmdRunAction):
-        return f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
-    elif isinstance(action, IPythonRunCellAction):
-        return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
-    elif isinstance(action, MessageAction):
-        return action.content
-    return ''
-
-
-def get_action_message(action: Action) -> dict[str, str] | None:
-    if (
-        isinstance(action, CmdRunAction)
-        or isinstance(action, IPythonRunCellAction)
-        or isinstance(action, MessageAction)
-    ):
-        return {
-            'role': 'user' if action.source == 'user' else 'assistant',
-            'content': action_to_str(action),
-        }
-    return None
-
-
-def get_observation_message(obs) -> dict[str, str] | None:
-    max_message_chars = config.get_llm_config_from_agent(
-        'CodeActSWEAgent'
-    ).max_message_chars
-    if isinstance(obs, CmdOutputObservation):
-        content = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
-        content += (
-            f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
-        )
-        return {'role': 'user', 'content': content}
-    elif isinstance(obs, IPythonRunCellObservation):
-        content = 'OBSERVATION:\n' + obs.content
-        # replace base64 images with a placeholder
-        splitted = content.split('\n')
-        for i, line in enumerate(splitted):
-            if '![image](data:image/png;base64,' in line:
-                splitted[i] = (
-                    '![image](data:image/png;base64, ...) already displayed to user'
-                )
-        content = '\n'.join(splitted)
-        content = truncate_content(content, max_message_chars)
-        return {'role': 'user', 'content': content}
-    return None
-
-
 def get_system_message() -> str:
     return f'{SYSTEM_PREFIX}\n\n{COMMAND_DOCS}\n\n{SYSTEM_SUFFIX}'
 
@@ -121,6 +73,53 @@ class CodeActSWEAgent(Agent):
         super().__init__(llm)
         self.reset()
 
+    def action_to_str(self, action: Action) -> str:
+        if isinstance(action, CmdRunAction):
+            return (
+                f'{action.thought}\n<execute_bash>\n{action.command}\n</execute_bash>'
+            )
+        elif isinstance(action, IPythonRunCellAction):
+            return f'{action.thought}\n<execute_ipython>\n{action.code}\n</execute_ipython>'
+        elif isinstance(action, MessageAction):
+            return action.content
+        return ''
+
+    def get_action_message(self, action: Action) -> dict[str, str] | None:
+        if (
+            isinstance(action, CmdRunAction)
+            or isinstance(action, IPythonRunCellAction)
+            or isinstance(action, MessageAction)
+        ):
+            return {
+                'role': 'user' if action.source == 'user' else 'assistant',
+                'content': self.action_to_str(action),
+            }
+        return None
+
+    def get_observation_message(self, obs: Observation) -> dict[str, str] | None:
+        max_message_chars = self.llm.config.max_message_chars
+        if isinstance(obs, CmdOutputObservation):
+            content = 'OBSERVATION:\n' + truncate_content(
+                obs.content, max_message_chars
+            )
+            content += (
+                f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]'
+            )
+            return {'role': 'user', 'content': content}
+        elif isinstance(obs, IPythonRunCellObservation):
+            content = 'OBSERVATION:\n' + obs.content
+            # replace base64 images with a placeholder
+            splitted = content.split('\n')
+            for i, line in enumerate(splitted):
+                if '![image](data:image/png;base64,' in line:
+                    splitted[i] = (
+                        '![image](data:image/png;base64, ...) already displayed to user'
+                    )
+            content = '\n'.join(splitted)
+            content = truncate_content(content, max_message_chars)
+            return {'role': 'user', 'content': content}
+        return None
+
     def reset(self) -> None:
         """Resets the CodeAct Agent."""
         super().reset()
@@ -165,11 +164,12 @@ class CodeActSWEAgent(Agent):
 
         for event in state.history.get_events():
             # create a regular message from an event
-            message = (
-                get_action_message(event)
-                if isinstance(event, Action)
-                else get_observation_message(event)
-            )
+            if isinstance(event, Action):
+                message = self.get_action_message(event)
+            elif isinstance(event, Observation):
+                message = self.get_observation_message(event)
+            else:
+                raise ValueError(f'Unknown event type: {type(event)}')
 
             # add regular message
             if message:

+ 23 - 23
agenthub/micro/agent.py

@@ -2,7 +2,6 @@ from jinja2 import BaseLoader, Environment
 
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.core.utils import json
 from opendevin.events.action import Action
 from opendevin.events.serialization.action import action_from_dict
@@ -27,32 +26,33 @@ def to_json(obj, **kwargs):
     return json.dumps(obj, **kwargs)
 
 
-def history_to_json(history: ShortTermHistory, max_events=20, **kwargs):
-    """Serialize and simplify history to str format"""
-    # TODO: get agent specific llm config
-    llm_config = config.get_llm_config()
-    max_message_chars = llm_config.max_message_chars
-
-    processed_history = []
-    event_count = 0
-
-    for event in history.get_events(reverse=True):
-        if event_count >= max_events:
-            break
-        processed_history.append(event_to_memory(event, max_message_chars))
-        event_count += 1
-
-    # history is in reverse order, let's fix it
-    processed_history.reverse()
-
-    return json.dumps(processed_history, **kwargs)
-
-
 class MicroAgent(Agent):
     VERSION = '1.0'
     prompt = ''
     agent_definition: dict = {}
 
+    def history_to_json(
+        self, history: ShortTermHistory, max_events: int = 20, **kwargs
+    ):
+        """
+        Serialize and simplify history to str format
+        """
+        processed_history = []
+        event_count = 0
+
+        for event in history.get_events(reverse=True):
+            if event_count >= max_events:
+                break
+            processed_history.append(
+                event_to_memory(event, self.llm.config.max_message_chars)
+            )
+            event_count += 1
+
+        # history is in reverse order, let's fix it
+        processed_history.reverse()
+
+        return json.dumps(processed_history, **kwargs)
+
     def __init__(self, llm: LLM):
         super().__init__(llm)
         if 'name' not in self.agent_definition:
@@ -66,7 +66,7 @@ class MicroAgent(Agent):
             state=state,
             instructions=instructions,
             to_json=to_json,
-            history_to_json=history_to_json,
+            history_to_json=self.history_to_json,
             delegates=self.delegates,
             latest_user_message=state.get_current_user_intent(),
         )

+ 12 - 12
agenthub/monologue_agent/agent.py

@@ -83,10 +83,7 @@ class MonologueAgent(Agent):
         self._add_initial_thoughts(task)
         self._initialized = True
 
-    def _add_initial_thoughts(self, task):
-        max_message_chars = config.get_llm_config_from_agent(
-            'MonologueAgent'
-        ).max_message_chars
+    def _add_initial_thoughts(self, task: str):
         previous_action = ''
         for thought in INITIAL_THOUGHTS:
             thought = thought.replace('$TASK', task)
@@ -103,7 +100,7 @@ class MonologueAgent(Agent):
                         content=thought, url='', screenshot=''
                     )
                 self.initial_thoughts.append(
-                    event_to_memory(observation, max_message_chars)
+                    event_to_memory(observation, self.llm.config.max_message_chars)
                 )
                 previous_action = ''
             else:
@@ -127,7 +124,9 @@ class MonologueAgent(Agent):
                     previous_action = ActionType.BROWSE
                 else:
                     action = MessageAction(thought)
-                self.initial_thoughts.append(event_to_memory(action, max_message_chars))
+                self.initial_thoughts.append(
+                    event_to_memory(action, self.llm.config.max_message_chars)
+                )
 
     def step(self, state: State) -> Action:
         """Modifies the current state by adding the most recent actions and observations, then prompts the model to think about it's next action to take using monologue, memory, and hint.
@@ -138,9 +137,6 @@ class MonologueAgent(Agent):
         Returns:
         - Action: The next action to take based on LLM response
         """
-        max_message_chars = config.get_llm_config_from_agent(
-            'MonologueAgent'
-        ).max_message_chars
         goal = state.get_current_user_intent()
         self._initialize(goal)
 
@@ -148,7 +144,9 @@ class MonologueAgent(Agent):
 
         # add the events from state.history
         for event in state.history.get_events():
-            recent_events.append(event_to_memory(event, max_message_chars))
+            recent_events.append(
+                event_to_memory(event, self.llm.config.max_message_chars)
+            )
 
         # add the last messages to long term memory
         if self.memory is not None:
@@ -158,10 +156,12 @@ class MonologueAgent(Agent):
             # this should still work
             # we will need to do this differently: find out if there really is an action or an observation in this step
             if last_action:
-                self.memory.add_event(event_to_memory(last_action, max_message_chars))
+                self.memory.add_event(
+                    event_to_memory(last_action, self.llm.config.max_message_chars)
+                )
             if last_observation:
                 self.memory.add_event(
-                    event_to_memory(last_observation, max_message_chars)
+                    event_to_memory(last_observation, self.llm.config.max_message_chars)
                 )
 
         # the action prompt with initial thoughts and recent events

+ 1 - 1
agenthub/planner_agent/agent.py

@@ -42,7 +42,7 @@ class PlannerAgent(Agent):
             'abandoned',
         ]:
             return AgentFinishAction()
-        prompt = get_prompt(state)
+        prompt = get_prompt(state, self.llm.config.max_message_chars)
         messages = [{'content': prompt, 'role': 'user'}]
         resp = self.llm.completion(messages=messages)
         return self.response_parser.parse(resp)

+ 2 - 6
agenthub/planner_agent/prompt.py

@@ -1,5 +1,4 @@
 from opendevin.controller.state.state import State
-from opendevin.core.config import config
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.schema import ActionType
 from opendevin.core.utils import json
@@ -116,8 +115,9 @@ def get_hint(latest_action_id: str) -> str:
     return hints.get(latest_action_id, '')
 
 
-def get_prompt(state: State) -> str:
+def get_prompt(state: State, max_message_chars: int) -> str:
     """Gets the prompt for the planner agent.
+
     Formatted with the most recent action-observation pairs, current task, and hint based on last action
 
     Parameters:
@@ -126,10 +126,6 @@ def get_prompt(state: State) -> str:
     Returns:
     - str: The formatted string prompt with historical values
     """
-    max_message_chars = config.get_llm_config_from_agent(
-        'PlannerAgent'
-    ).max_message_chars
-
     # the plan
     plan_str = json.dumps(state.root_task.to_dict(), indent=2)
 

+ 1 - 1
opendevin/controller/agent_controller.py

@@ -248,7 +248,7 @@ class AgentController:
     async def start_delegate(self, action: AgentDelegateAction):
         agent_cls: Type[Agent] = Agent.get_cls(action.agent)
         llm_config = config.get_llm_config_from_agent(action.agent)
-        llm = LLM(llm_config=llm_config)
+        llm = LLM(config=llm_config)
         delegate_agent = agent_cls(llm=llm)
         state = State(
             inputs=action.inputs or {},

+ 1 - 1
opendevin/core/config.py

@@ -198,7 +198,7 @@ class AppConfig(metaclass=Singleton):
         file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed.
     """
 
-    llms: dict = field(default_factory=dict)
+    llms: dict[str, LLMConfig] = field(default_factory=dict)
     agents: dict = field(default_factory=dict)
     default_agent: str = 'CodeActAgent'
     sandbox: SandboxConfig = field(default_factory=SandboxConfig)

+ 2 - 2
opendevin/core/main.py

@@ -52,7 +52,7 @@ async def run_agent_controller(
     """
     # Logging
     logger.info(
-        f'Running agent {agent.name}, model {agent.llm.model_name}, with task: "{task_str}"'
+        f'Running agent {agent.name}, model {agent.llm.config.model}, with task: "{task_str}"'
     )
 
     # set up the event stream
@@ -163,7 +163,7 @@ if __name__ == '__main__':
         if llm_config is None:
             raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
         config.set_llm_config(llm_config)
-    llm = LLM(llm_config=config.get_llm_config_from_agent(args.agent_cls))
+    llm = LLM(config=config.get_llm_config_from_agent(args.agent_cls))
 
     # Create the agent
     AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)

+ 62 - 142
opendevin/llm/llm.py

@@ -1,6 +1,9 @@
+import copy
 import warnings
 from functools import partial
 
+from opendevin.core.config import LLMConfig
+
 with warnings.catch_warnings():
     warnings.simplefilter('ignore')
     import litellm
@@ -21,7 +24,6 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from opendevin.core.config import config
 from opendevin.core.logger import llm_prompt_logger, llm_response_logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.metrics import Metrics
@@ -35,155 +37,71 @@ class LLM:
     """The LLM class represents a Language Model instance.
 
     Attributes:
-        model_name (str): The name of the language model.
-        api_key (str): The API key for accessing the language model.
-        base_url (str): The base URL for the language model API.
-        api_version (str): The version of the API to use.
-        max_input_tokens (int): The maximum number of tokens to send to the LLM per task.
-        max_output_tokens (int): The maximum number of tokens to receive from the LLM per task.
-        llm_timeout (int): The maximum time to wait for a response in seconds.
-        custom_llm_provider (str): A custom LLM provider.
+        config: an LLMConfig object specifying the configuration of the LLM.
     """
 
     def __init__(
         self,
-        model=None,
-        api_key=None,
-        base_url=None,
-        api_version=None,
-        num_retries=None,
-        retry_min_wait=None,
-        retry_max_wait=None,
-        llm_timeout=None,
-        llm_temperature=None,
-        llm_top_p=None,
-        custom_llm_provider=None,
-        max_input_tokens=None,
-        max_output_tokens=None,
-        llm_config=None,
-        metrics=None,
-        cost_metric_supported=True,
-        input_cost_per_token=None,
-        output_cost_per_token=None,
+        config: LLMConfig,
+        metrics: Metrics | None = None,
     ):
         """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
 
         Passing simple parameters always overrides config.
 
         Args:
-            model (str, optional): The name of the language model. Defaults to LLM_MODEL.
-            api_key (str, optional): The API key for accessing the language model. Defaults to LLM_API_KEY.
-            base_url (str, optional): The base URL for the language model API. Defaults to LLM_BASE_URL. Not necessary for OpenAI.
-            api_version (str, optional): The version of the API to use. Defaults to LLM_API_VERSION. Not necessary for OpenAI.
-            num_retries (int, optional): The number of retries for API calls. Defaults to LLM_NUM_RETRIES.
-            retry_min_wait (int, optional): The minimum time to wait between retries in seconds. Defaults to LLM_RETRY_MIN_TIME.
-            retry_max_wait (int, optional): The maximum time to wait between retries in seconds. Defaults to LLM_RETRY_MAX_TIME.
-            max_input_tokens (int, optional): The maximum number of tokens to send to the LLM per task. Defaults to LLM_MAX_INPUT_TOKENS.
-            max_output_tokens (int, optional): The maximum number of tokens to receive from the LLM per task. Defaults to LLM_MAX_OUTPUT_TOKENS.
-            custom_llm_provider (str, optional): A custom LLM provider. Defaults to LLM_CUSTOM_LLM_PROVIDER.
-            llm_timeout (int, optional): The maximum time to wait for a response in seconds. Defaults to LLM_TIMEOUT.
-            llm_temperature (float, optional): The temperature for LLM sampling. Defaults to LLM_TEMPERATURE.
-            metrics (Metrics, optional): The metrics object to use. Defaults to None.
-            cost_metric_supported (bool, optional): Whether the cost metric is supported. Defaults to True.
-            input_cost_per_token (float, optional): The cost per input token.
-            output_cost_per_token (float, optional): The cost per output token.
+            config: The LLM configuration
         """
-        if llm_config is None:
-            llm_config = config.get_llm_config()
-        model = model if model is not None else llm_config.model
-        api_key = api_key if api_key is not None else llm_config.api_key
-        base_url = base_url if base_url is not None else llm_config.base_url
-        api_version = api_version if api_version is not None else llm_config.api_version
-        num_retries = num_retries if num_retries is not None else llm_config.num_retries
-        retry_min_wait = (
-            retry_min_wait if retry_min_wait is not None else llm_config.retry_min_wait
-        )
-        retry_max_wait = (
-            retry_max_wait if retry_max_wait is not None else llm_config.retry_max_wait
-        )
-        llm_timeout = llm_timeout if llm_timeout is not None else llm_config.timeout
-        llm_temperature = (
-            llm_temperature if llm_temperature is not None else llm_config.temperature
-        )
-        llm_top_p = llm_top_p if llm_top_p is not None else llm_config.top_p
-        custom_llm_provider = (
-            custom_llm_provider
-            if custom_llm_provider is not None
-            else llm_config.custom_llm_provider
-        )
-        max_input_tokens = (
-            max_input_tokens
-            if max_input_tokens is not None
-            else llm_config.max_input_tokens
-        )
-        max_output_tokens = (
-            max_output_tokens
-            if max_output_tokens is not None
-            else llm_config.max_output_tokens
-        )
-        input_cost_per_token = (
-            input_cost_per_token
-            if input_cost_per_token is not None
-            else llm_config.input_cost_per_token
-        )
-        output_cost_per_token = (
-            output_cost_per_token
-            if output_cost_per_token is not None
-            else llm_config.output_cost_per_token
-        )
-        metrics = metrics if metrics is not None else Metrics()
-
-        logger.info(f'Initializing LLM with model: {model}')
-        self.model_name = model
-        self.api_key = api_key
-        self.base_url = base_url
-        self.api_version = api_version
-        self.max_input_tokens = max_input_tokens
-        self.max_output_tokens = max_output_tokens
-        self.input_cost_per_token = input_cost_per_token
-        self.output_cost_per_token = output_cost_per_token
-        self.llm_timeout = llm_timeout
-        self.custom_llm_provider = custom_llm_provider
-        self.metrics = metrics
-        self.cost_metric_supported = cost_metric_supported
+
+        self.config = copy.deepcopy(config)
+        self.metrics = metrics if metrics is not None else Metrics()
+        self.cost_metric_supported = True
 
         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
-            if not self.model_name.startswith('openrouter'):
-                self.model_info = litellm.get_model_info(self.model_name.split(':')[0])
+            if not config.model.startswith('openrouter'):
+                self.model_info = litellm.get_model_info(config.model.split(':')[0])
             else:
-                self.model_info = litellm.get_model_info(self.model_name)
+                self.model_info = litellm.get_model_info(config.model)
         # noinspection PyBroadException
         except Exception:
-            logger.warning(f'Could not get model info for {self.model_name}')
-
-        if self.max_input_tokens is None:
-            if self.model_info is not None and 'max_input_tokens' in self.model_info:
-                self.max_input_tokens = self.model_info['max_input_tokens']
+            logger.warning(f'Could not get model info for {config.model}')
+
+        # Set the max tokens in an LM-specific way if not set
+        if config.max_input_tokens is None:
+            if (
+                self.model_info is not None
+                and 'max_input_tokens' in self.model_info
+                and isinstance(self.model_info['max_input_tokens'], int)
+            ):
+                self.config.max_input_tokens = self.model_info['max_input_tokens']
             else:
                 # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
-                self.max_input_tokens = 4096
-
-        if self.max_output_tokens is None:
-            if self.model_info is not None and 'max_output_tokens' in self.model_info:
-                self.max_output_tokens = self.model_info['max_output_tokens']
+                self.config.max_input_tokens = 4096
+
+        if config.max_output_tokens is None:
+            if (
+                self.model_info is not None
+                and 'max_output_tokens' in self.model_info
+                and isinstance(self.model_info['max_output_tokens'], int)
+            ):
+                self.config.max_output_tokens = self.model_info['max_output_tokens']
             else:
-                # Enough tokens for most output actions, and not too many for a bad llm to get carried away responding
-                # with thousands of unwanted tokens
-                self.max_output_tokens = 1024
+                # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
+                self.config.max_output_tokens = 1024
 
         self._completion = partial(
             litellm_completion,
-            model=self.model_name,
-            api_key=self.api_key,
-            base_url=self.base_url,
-            api_version=self.api_version,
-            custom_llm_provider=custom_llm_provider,
-            max_tokens=self.max_output_tokens,
-            timeout=self.llm_timeout,
-            temperature=llm_temperature,
-            top_p=llm_top_p,
+            model=self.config.model,
+            api_key=self.config.api_key,
+            base_url=self.config.base_url,
+            api_version=self.config.api_version,
+            custom_llm_provider=self.config.custom_llm_provider,
+            max_tokens=self.config.max_output_tokens,
+            timeout=self.config.timeout,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
         )
 
         completion_unwrapped = self._completion
@@ -197,8 +115,10 @@ class LLM:
 
         @retry(
             reraise=True,
-            stop=stop_after_attempt(num_retries),
-            wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
+            stop=stop_after_attempt(config.num_retries),
+            wait=wait_random_exponential(
+                min=config.retry_min_wait, max=config.retry_max_wait
+            ),
             retry=retry_if_exception_type(
                 (
                     RateLimitError,
@@ -267,7 +187,7 @@ class LLM:
         Returns:
             int: The number of tokens.
         """
-        return litellm.token_counter(model=self.model_name, messages=messages)
+        return litellm.token_counter(model=self.config.model, messages=messages)
 
     def is_local(self):
         """Determines if the system is using a locally running LLM.
@@ -275,12 +195,12 @@ class LLM:
         Returns:
             boolean: True if executing a local model.
         """
-        if self.base_url is not None:
+        if self.config.base_url is not None:
             for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
-                if substring in self.base_url:
+                if substring in self.config.base_url:
                     return True
-        elif self.model_name is not None:
-            if self.model_name.startswith('ollama'):
+        elif self.config.model is not None:
+            if self.config.model.startswith('ollama'):
                 return True
         return False
 
@@ -299,12 +219,12 @@ class LLM:
 
         extra_kwargs = {}
         if (
-            self.input_cost_per_token is not None
-            and self.output_cost_per_token is not None
+            self.config.input_cost_per_token is not None
+            and self.config.output_cost_per_token is not None
         ):
             cost_per_token = CostPerToken(
-                input_cost_per_token=self.input_cost_per_token,
-                output_cost_per_token=self.output_cost_per_token,
+                input_cost_per_token=self.config.input_cost_per_token,
+                output_cost_per_token=self.config.output_cost_per_token,
             )
             logger.info(f'Using custom cost per token: {cost_per_token}')
             extra_kwargs['custom_cost_per_token'] = cost_per_token
@@ -322,11 +242,11 @@ class LLM:
         return 0.0
 
     def __str__(self):
-        if self.api_version:
-            return f'LLM(model={self.model_name}, api_version={self.api_version}, base_url={self.base_url})'
-        elif self.base_url:
-            return f'LLM(model={self.model_name}, base_url={self.base_url})'
-        return f'LLM(model={self.model_name})'
+        if self.config.api_version:
+            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
+        elif self.config.base_url:
+            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
+        return f'LLM(model={self.config.model})'
 
     def __repr__(self):
         return str(self)
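
One detail worth noting from the hunk above: the constructor deep-copies the incoming config (self.config = copy.deepcopy(config)), so when the LLM fills in max_input_tokens / max_output_tokens from the model info or the 4096 / 1024 fallbacks, the caller's config object is left untouched. A small sketch of that behavior, assuming a config whose token limits are left unset:

    from opendevin.core.config import LLMConfig
    from opendevin.llm.llm import LLM

    shared = LLMConfig()          # assume max_input_tokens / max_output_tokens start unset
    llm = LLM(config=shared)      # __init__ stores copy.deepcopy(shared)

    print(llm.config.max_input_tokens)   # resolved: the model's limit, or the 4096 fallback
    print(shared.max_input_tokens)       # unchanged on the shared object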

+ 1 - 1
opendevin/server/session/agent.py

@@ -97,7 +97,7 @@ class AgentSession:
 
         # TODO: override other LLM config & agent config groups (#2075)
 
-        llm = LLM(llm_config=config.get_llm_config_from_agent(agent_cls))
+        llm = LLM(config=config.get_llm_config_from_agent(agent_cls))
         agent = Agent.get_cls(agent_cls)(llm)
         logger.info(f'Creating agent {agent.name} using LLM {llm}')
         if isinstance(agent, CodeActAgent):

+ 10 - 8
tests/integration/test_agent.py

@@ -7,7 +7,7 @@ import pytest
 
 from opendevin.controller.agent import Agent
 from opendevin.controller.state.state import State
-from opendevin.core.config import parse_arguments
+from opendevin.core.config import LLMConfig, parse_arguments
 from opendevin.core.main import run_agent_controller
 from opendevin.core.schema import AgentState
 from opendevin.events.action import (
@@ -44,20 +44,22 @@ print(f'workspace_mount_path_in_sandbox: {workspace_mount_path_in_sandbox}')
     os.getenv('DEFAULT_AGENT') == 'ManagerAgent',
     reason='Manager agent is not capable of finishing this in reasonable steps yet',
 )
-def test_write_simple_script():
+def test_write_simple_script() -> None:
     task = "Write a shell script 'hello.sh' that prints 'hello'. Do not ask me for confirmation at any point."
     args = parse_arguments()
 
     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
 
     final_state: State | None = asyncio.run(
         run_agent_controller(agent, task, exit_on_message=True)
     )
+    assert final_state is not None
     assert final_state.agent_state == AgentState.STOPPED
     assert final_state.last_error is None
 
     # Verify the script file exists
+    assert workspace_base is not None
     script_path = os.path.join(workspace_base, 'hello.sh')
     assert os.path.exists(script_path), 'The file "hello.sh" does not exist'
 
@@ -103,7 +105,7 @@ def test_edits():
         shutil.copy(os.path.join(source_dir, file), dest_file)
 
     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
 
     # Execute the task
     task = 'Fix typos in bad.txt. Do not ask me for confirmation at any point.'
@@ -137,7 +139,7 @@ def test_ipython():
     args = parse_arguments()
 
     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
 
     # Execute the task
     task = "Use Jupyter IPython to write a text file containing 'hello world' to '/workspace/test.txt'. Do not ask me for confirmation at any point."
@@ -171,7 +173,7 @@ def test_simple_task_rejection():
     args = parse_arguments()
 
     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
 
     # Give an impossible task to do: cannot write a commit message because
     # the workspace is not a git repo
@@ -195,7 +197,7 @@ def test_ipython_module():
     args = parse_arguments()
 
     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
 
     # Execute the task
     task = "Install and import pymsgbox==1.0.9 and print it's version in /workspace/test.txt. Do not ask me for confirmation at any point."
@@ -235,7 +237,7 @@ def test_browse_internet(http_server):
     args = parse_arguments()
 
     # Create the agent
-    agent = Agent.get_cls(args.agent_cls)(llm=LLM())
+    agent = Agent.get_cls(args.agent_cls)(llm=LLM(LLMConfig()))
 
     # Execute the task
     task = 'Browse localhost:8000, and tell me the ultimate answer to life. Do not ask me for confirmation at any point.'

+ 4 - 5
tests/unit/test_action_serialization.py

@@ -1,4 +1,3 @@
-from opendevin.core.config import config
 from opendevin.events.action import (
     Action,
     AddTaskAction,
@@ -20,7 +19,9 @@ from opendevin.events.serialization import (
 )
 
 
-def serialization_deserialization(original_action_dict, cls):
+def serialization_deserialization(
+    original_action_dict, cls, max_message_chars: int = 10000
+):
     action_instance = event_from_dict(original_action_dict)
     assert isinstance(
         action_instance, Action
@@ -29,9 +30,7 @@ def serialization_deserialization(original_action_dict, cls):
         action_instance, cls
     ), f'The action instance should be an instance of {cls.__name__}.'
     serialized_action_dict = event_to_dict(action_instance)
-    serialized_action_memory = event_to_memory(
-        action_instance, config.get_llm_config().max_message_chars
-    )
+    serialized_action_memory = event_to_memory(action_instance, max_message_chars)
     serialized_action_dict.pop('message')
     assert (
         serialized_action_dict == original_action_dict

+ 4 - 3
tests/unit/test_observation_serialization.py

@@ -1,4 +1,3 @@
-from opendevin.core.config import config
 from opendevin.events.observation import (
     CmdOutputObservation,
     Observation,
@@ -10,7 +9,9 @@ from opendevin.events.serialization import (
 )
 
 
-def serialization_deserialization(original_observation_dict, cls):
+def serialization_deserialization(
+    original_observation_dict, cls, max_message_chars: int = 10000
+):
     observation_instance = event_from_dict(original_observation_dict)
     assert isinstance(
         observation_instance, Observation
@@ -20,7 +21,7 @@ def serialization_deserialization(original_observation_dict, cls):
     ), 'The observation instance should be an instance of CmdOutputObservation.'
     serialized_observation_dict = event_to_dict(observation_instance)
     serialized_observation_memory = event_to_memory(
-        observation_instance, config.get_llm_config().max_message_chars
+        observation_instance, max_message_chars
     )
     assert (
         serialized_observation_dict == original_observation_dict