Procházet zdrojové kódy

Rework --llm-config CLI arg (#2957)

Boxuan Li před 1 rokem
rodič
revize
e3e437fcc2

+ 1 - 1
opendevin/core/config.py

@@ -656,7 +656,7 @@ def get_parser() -> argparse.ArgumentParser:
         '--llm-config',
         default=None,
         type=str,
-        help='The group of llm settings, e.g. "llama3" for [llm.llama3] section in the toml file. Overrides model if both are provided.',
+        help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml',
     )
     return parser
 

+ 3 - 4
opendevin/core/main.py

@@ -157,14 +157,13 @@ if __name__ == '__main__':
     else:
         raise ValueError('No task provided. Please specify a task through -t, -f.')
 
-    # Figure out the LLM config
+    # Override default LLM configs ([llm] section in config.toml)
     if args.llm_config:
         llm_config = get_llm_config_arg(args.llm_config)
         if llm_config is None:
             raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
-        llm = LLM(llm_config=llm_config)
-    else:
-        llm = LLM(llm_config=config.get_llm_config_from_agent(args.agent_cls))
+        config.set_llm_config(llm_config)
+    llm = LLM(llm_config=config.get_llm_config_from_agent(args.agent_cls))
 
     # Create the agent
     AgentCls: Type[Agent] = Agent.get_cls(args.agent_cls)

+ 3 - 3
tests/unit/test_arg_parser.py

@@ -41,9 +41,9 @@ options:
   --eval-note EVAL_NOTE
                         The note to add to the evaluation directory
   -l LLM_CONFIG, --llm-config LLM_CONFIG
-                        The group of llm settings, e.g. "llama3" for
-                        [llm.llama3] section in the toml file. Overrides model
-                        if both are provided.
+                        Replace default LLM ([llm] section in config.toml)
+                        config with the specified LLM config, e.g. "llama3"
+                        for [llm.llama3] section in config.toml
 """
 
     actual_lines = captured.out.strip().split('\n')