config.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import argparse
  2. import logging
  3. import os
  4. import pathlib
  5. import platform
  6. from dataclasses import dataclass, field, fields, is_dataclass
  7. from types import UnionType
  8. from typing import Any, ClassVar, get_args, get_origin
  9. import toml
  10. from dotenv import load_dotenv
  11. from opendevin.core.utils import Singleton
  12. logger = logging.getLogger(__name__)
  13. load_dotenv()
  14. @dataclass
  15. class LLMConfig(metaclass=Singleton):
  16. model: str = 'gpt-3.5-turbo'
  17. api_key: str | None = None
  18. base_url: str | None = None
  19. api_version: str | None = None
  20. embedding_model: str = 'local'
  21. embedding_base_url: str | None = None
  22. embedding_deployment_name: str | None = None
  23. aws_access_key_id: str | None = None
  24. aws_secret_access_key: str | None = None
  25. aws_region_name: str | None = None
  26. num_retries: int = 5
  27. retry_min_wait: int = 3
  28. retry_max_wait: int = 60
  29. timeout: int | None = None
  30. max_chars: int = 5_000_000 # fallback for token counting
  31. temperature: float = 0
  32. top_p: float = 0.5
  33. custom_llm_provider: str | None = None
  34. max_input_tokens: int | None = None
  35. max_output_tokens: int | None = None
  36. def defaults_to_dict(self) -> dict:
  37. """
  38. Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.
  39. """
  40. dict = {}
  41. for f in fields(self):
  42. dict[f.name] = get_field_info(f)
  43. return dict
  44. @dataclass
  45. class AgentConfig(metaclass=Singleton):
  46. name: str = 'CodeActAgent'
  47. memory_enabled: bool = False
  48. memory_max_threads: int = 2
  49. def defaults_to_dict(self) -> dict:
  50. """
  51. Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.
  52. """
  53. dict = {}
  54. for f in fields(self):
  55. dict[f.name] = get_field_info(f)
  56. return dict
  57. @dataclass
  58. class AppConfig(metaclass=Singleton):
  59. llm: LLMConfig = field(default_factory=LLMConfig)
  60. agent: AgentConfig = field(default_factory=AgentConfig)
  61. runtime: str = 'server'
  62. file_store: str = 'memory'
  63. file_store_path: str = '/tmp/file_store'
  64. workspace_base: str = os.path.join(os.getcwd(), 'workspace')
  65. workspace_mount_path: str | None = None
  66. workspace_mount_path_in_sandbox: str = '/workspace'
  67. workspace_mount_rewrite: str | None = None
  68. cache_dir: str = '/tmp/cache'
  69. sandbox_container_image: str = 'ghcr.io/opendevin/sandbox' + (
  70. f':{os.getenv("OPEN_DEVIN_BUILD_VERSION")}'
  71. if os.getenv('OPEN_DEVIN_BUILD_VERSION')
  72. else ':main'
  73. )
  74. run_as_devin: bool = True
  75. max_iterations: int = 100
  76. e2b_api_key: str = ''
  77. sandbox_type: str = 'ssh' # Can be 'ssh', 'exec', or 'e2b'
  78. use_host_network: bool = False
  79. ssh_hostname: str = 'localhost'
  80. disable_color: bool = False
  81. sandbox_user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
  82. sandbox_timeout: int = 120
  83. github_token: str | None = None
  84. debug: bool = False
  85. defaults_dict: ClassVar[dict] = {}
  86. def __post_init__(self):
  87. """
  88. Post-initialization hook, called when the instance is created with only default values.
  89. """
  90. AppConfig.defaults_dict = self.defaults_to_dict()
  91. def defaults_to_dict(self) -> dict:
  92. """
  93. Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.
  94. """
  95. dict = {}
  96. for f in fields(self):
  97. field_value = getattr(self, f.name)
  98. # dataclasses compute their defaults themselves
  99. if is_dataclass(type(field_value)):
  100. dict[f.name] = field_value.defaults_to_dict()
  101. else:
  102. dict[f.name] = get_field_info(f)
  103. return dict
  104. def get_field_info(field):
  105. """
  106. Extract information about a dataclass field: type, optional, and default.
  107. Args:
  108. field: The field to extract information from.
  109. Returns: A dict with the field's type, whether it's optional, and its default value.
  110. """
  111. field_type = field.type
  112. optional = False
  113. # for types like str | None, find the non-None type and set optional to True
  114. # this is useful for the frontend to know if a field is optional
  115. # and to show the correct type in the UI
  116. # Note: this only works for UnionTypes with None as one of the types
  117. if get_origin(field_type) is UnionType:
  118. types = get_args(field_type)
  119. non_none_arg = next((t for t in types if t is not type(None)), None)
  120. if non_none_arg is not None:
  121. field_type = non_none_arg
  122. optional = True
  123. # type name in a pretty format
  124. type_name = (
  125. field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
  126. )
  127. # default is always present
  128. default = field.default
  129. # return a schema with the useful info for frontend
  130. return {'type': type_name.lower(), 'optional': optional, 'default': default}
  131. def load_from_env(config: AppConfig, env_or_toml_dict: dict | os._Environ):
  132. """Reads the env-style vars and sets config attributes based on env vars or a config.toml dict.
  133. Compatibility with vars like LLM_BASE_URL, AGENT_MEMORY_ENABLED and others.
  134. Args:
  135. config: The AppConfig object to set attributes on.
  136. env_or_toml_dict: The environment variables or a config.toml dict.
  137. """
  138. def get_optional_type(union_type: UnionType) -> Any:
  139. """Returns the non-None type from an Union."""
  140. types = get_args(union_type)
  141. return next((t for t in types if t is not type(None)), None)
  142. # helper function to set attributes based on env vars
  143. def set_attr_from_env(sub_config: Any, prefix=''):
  144. """Set attributes of a config dataclass based on environment variables."""
  145. for field_name, field_type in sub_config.__annotations__.items():
  146. # compute the expected env var name from the prefix and field name
  147. # e.g. LLM_BASE_URL
  148. env_var_name = (prefix + field_name).upper()
  149. if is_dataclass(field_type):
  150. # nested dataclass
  151. nested_sub_config = getattr(sub_config, field_name)
  152. # the agent field: the env var for agent.name is just 'AGENT'
  153. if field_name == 'agent' and 'AGENT' in env_or_toml_dict:
  154. setattr(nested_sub_config, 'name', env_or_toml_dict[env_var_name])
  155. set_attr_from_env(nested_sub_config, prefix=field_name + '_')
  156. elif env_var_name in env_or_toml_dict:
  157. # convert the env var to the correct type and set it
  158. value = env_or_toml_dict[env_var_name]
  159. try:
  160. # if it's an optional type, get the non-None type
  161. if get_origin(field_type) is UnionType:
  162. field_type = get_optional_type(field_type)
  163. # Attempt to cast the env var to type hinted in the dataclass
  164. if field_type is bool:
  165. cast_value = str(value).lower() in ['true', '1']
  166. else:
  167. cast_value = field_type(value)
  168. setattr(sub_config, field_name, cast_value)
  169. except (ValueError, TypeError):
  170. logger.error(
  171. f'Error setting env var {env_var_name}={value}: check that the value is of the right type'
  172. )
  173. # Start processing from the root of the config object
  174. set_attr_from_env(config)
def load_from_toml(config: AppConfig, toml_file: str = 'config.toml'):
    """Load the config from the toml file. Supports both styles of config vars.

    Args:
        config: The AppConfig object to update attributes of.
        toml_file: Path of the toml file to read (missing file is a no-op).
    """
    # try to read the config.toml file into the config object
    toml_config = {}
    try:
        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
            toml_config = toml.load(toml_contents)
    except FileNotFoundError:
        # the file is optional, we don't need to do anything
        return
    except toml.TomlDecodeError:
        logger.warning(
            'Cannot parse config from toml, toml values have not been applied.',
            exc_info=False,
        )
        return

    # if there was an exception or core is not in the toml, try to use the old-style toml
    if 'core' not in toml_config:
        # re-use the env loader to set the config from env-style vars
        load_from_env(config, toml_config)
        return

    core_config = toml_config['core']

    try:
        # set llm config from the toml file
        llm_config = config.llm
        if 'llm' in toml_config:
            llm_config = LLMConfig(**toml_config['llm'])
        # set agent config from the toml file
        agent_config = config.agent
        if 'agent' in toml_config:
            agent_config = AgentConfig(**toml_config['agent'])
        # update the config object with the new values
        # NOTE(review): this rebinds the local name instead of mutating the
        # `config` argument; it presumably relies on the Singleton metaclass
        # returning/updating the one shared AppConfig instance — confirm
        # against opendevin.core.utils.Singleton.
        config = AppConfig(llm=llm_config, agent=agent_config, **core_config)
    except (TypeError, KeyError):
        # a section had unexpected keys/values; leave config unchanged
        logger.warning(
            'Cannot parse config from toml, toml values have not been applied.',
            exc_info=False,
        )
  216. def finalize_config(config: AppConfig):
  217. """
  218. More tweaks to the config after it's been loaded.
  219. """
  220. # Set workspace_mount_path if not set by the user
  221. if config.workspace_mount_path is None:
  222. config.workspace_mount_path = os.path.abspath(config.workspace_base)
  223. config.workspace_base = os.path.abspath(config.workspace_base)
  224. # In local there is no sandbox, the workspace will have the same pwd as the host
  225. if config.sandbox_type == 'local':
  226. config.workspace_mount_path_in_sandbox = config.workspace_mount_path
  227. if config.workspace_mount_rewrite: # and not config.workspace_mount_path:
  228. # TODO why do we need to check if workspace_mount_path is None?
  229. base = config.workspace_base or os.getcwd()
  230. parts = config.workspace_mount_rewrite.split(':')
  231. config.workspace_mount_path = base.replace(parts[0], parts[1])
  232. if config.llm.embedding_base_url is None:
  233. config.llm.embedding_base_url = config.llm.base_url
  234. if config.use_host_network and platform.system() == 'Darwin':
  235. logger.warning(
  236. 'Please upgrade to Docker Desktop 4.29.0 or later to use host network mode on macOS. '
  237. 'See https://github.com/docker/roadmap/issues/238#issuecomment-2044688144 for more information.'
  238. )
  239. # make sure cache dir exists
  240. if config.cache_dir:
  241. pathlib.Path(config.cache_dir).mkdir(parents=True, exist_ok=True)
# Build the global singleton config: defaults first, then config.toml
# overrides, then environment variables (loaded last, so they win),
# then final path/cache tweaks.
config = AppConfig()
load_from_toml(config)
load_from_env(config, os.environ)
finalize_config(config)
  246. # Utility function for command line --group argument
  247. def get_llm_config_arg(llm_config_arg: str):
  248. """
  249. Get a group of llm settings from the config file.
  250. """
  251. # keep only the name, just in case
  252. llm_config_arg = llm_config_arg.strip('[]')
  253. logger.info(f'Loading llm config from {llm_config_arg}')
  254. # load the toml file
  255. try:
  256. with open('config.toml', 'r', encoding='utf-8') as toml_file:
  257. toml_config = toml.load(toml_file)
  258. except FileNotFoundError:
  259. return None
  260. except toml.TomlDecodeError as e:
  261. logger.error(f'Cannot parse llm group from {llm_config_arg}. Exception: {e}')
  262. return None
  263. # update the llm config with the specified section
  264. if llm_config_arg in toml_config:
  265. return LLMConfig(**toml_config[llm_config_arg])
  266. logger.debug(f'Loading from toml failed for {llm_config_arg}')
  267. return None
  268. # Command line arguments
  269. def get_parser():
  270. """
  271. Get the parser for the command line arguments.
  272. """
  273. parser = argparse.ArgumentParser(description='Run an agent with a specific task')
  274. parser.add_argument(
  275. '-d',
  276. '--directory',
  277. type=str,
  278. help='The working directory for the agent',
  279. )
  280. parser.add_argument(
  281. '-t', '--task', type=str, default='', help='The task for the agent to perform'
  282. )
  283. parser.add_argument(
  284. '-f',
  285. '--file',
  286. type=str,
  287. help='Path to a file containing the task. Overrides -t if both are provided.',
  288. )
  289. parser.add_argument(
  290. '-c',
  291. '--agent-cls',
  292. default=config.agent.name,
  293. type=str,
  294. help='The agent class to use',
  295. )
  296. parser.add_argument(
  297. '-m',
  298. '--model-name',
  299. default=config.llm.model,
  300. type=str,
  301. help='The (litellm) model name to use',
  302. )
  303. parser.add_argument(
  304. '-i',
  305. '--max-iterations',
  306. default=config.max_iterations,
  307. type=int,
  308. help='The maximum number of iterations to run the agent',
  309. )
  310. parser.add_argument(
  311. '-n',
  312. '--max-chars',
  313. default=config.llm.max_chars,
  314. type=int,
  315. help='The maximum number of characters to send to and receive from LLM per task',
  316. )
  317. parser.add_argument(
  318. '--eval-output-dir',
  319. default='evaluation/evaluation_outputs/outputs',
  320. type=str,
  321. help='The directory to save evaluation output',
  322. )
  323. parser.add_argument(
  324. '--eval-n-limit',
  325. default=None,
  326. type=int,
  327. help='The number of instances to evaluate',
  328. )
  329. parser.add_argument(
  330. '--eval-num-workers',
  331. default=4,
  332. type=int,
  333. help='The number of workers to use for evaluation',
  334. )
  335. parser.add_argument(
  336. '--eval-note',
  337. default=None,
  338. type=str,
  339. help='The note to add to the evaluation directory',
  340. )
  341. parser.add_argument(
  342. '-l',
  343. '--llm-config',
  344. default=None,
  345. type=str,
  346. help='The group of llm settings, e.g. a [llama3] section in the toml file. Overrides model if both are provided.',
  347. )
  348. return parser
  349. def parse_arguments():
  350. """
  351. Parse the command line arguments.
  352. """
  353. parser = get_parser()
  354. args, _ = parser.parse_known_args()
  355. if args.directory:
  356. config.workspace_base = os.path.abspath(args.directory)
  357. print(f'Setting workspace base to {config.workspace_base}')
  358. return args
  359. args = parse_arguments()