config.py

import argparse
import logging
import os
import pathlib
import platform
from dataclasses import dataclass, field, fields, is_dataclass
from types import UnionType
from typing import Any, ClassVar, get_args, get_origin

import toml
from dotenv import load_dotenv

from opendevin.core.utils import Singleton

logger = logging.getLogger(__name__)

load_dotenv()


@dataclass
class LLMConfig(metaclass=Singleton):
    model: str = 'gpt-3.5-turbo'
    api_key: str | None = None
    base_url: str | None = None
    api_version: str | None = None
    embedding_model: str = 'local'
    embedding_base_url: str | None = None
    embedding_deployment_name: str | None = None
    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    aws_region_name: str | None = None
    num_retries: int = 5
    retry_min_wait: int = 3
    retry_max_wait: int = 60
    timeout: int | None = None
    max_chars: int = 5_000_000  # fallback for token counting
    temperature: float = 0
    top_p: float = 0.5
    custom_llm_provider: str | None = None
    max_input_tokens: int | None = None
    max_output_tokens: int | None = None

    def defaults_to_dict(self) -> dict:
        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
        result = {}
        for f in fields(self):
            result[f.name] = get_field_info(f)
        return result


@dataclass
class AgentConfig(metaclass=Singleton):
    name: str = 'CodeActAgent'
    memory_enabled: bool = False
    memory_max_threads: int = 2

    def defaults_to_dict(self) -> dict:
        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
        result = {}
        for f in fields(self):
            result[f.name] = get_field_info(f)
        return result


@dataclass
class AppConfig(metaclass=Singleton):
    llm: LLMConfig = field(default_factory=LLMConfig)
    agent: AgentConfig = field(default_factory=AgentConfig)
    workspace_base: str = os.getcwd()
    workspace_mount_path: str = os.getcwd()
    workspace_mount_path_in_sandbox: str = '/workspace'
    workspace_mount_rewrite: str | None = None
    cache_dir: str = '/tmp/cache'
    sandbox_container_image: str = 'ghcr.io/opendevin/sandbox' + (
        f':{os.getenv("OPEN_DEVIN_BUILD_VERSION")}'
        if os.getenv('OPEN_DEVIN_BUILD_VERSION')
        else ':main'
    )
    run_as_devin: bool = True
    max_iterations: int = 100
    e2b_api_key: str = ''
    sandbox_type: str = 'ssh'  # Can be 'ssh', 'exec', or 'e2b'
    use_host_network: bool = False
    ssh_hostname: str = 'localhost'
    disable_color: bool = False
    sandbox_user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
    sandbox_timeout: int = 120
    github_token: str | None = None
    debug: bool = False

    defaults_dict: ClassVar[dict] = {}

    def __post_init__(self):
        """Post-initialization hook, called when the instance is created with only default values."""
        AppConfig.defaults_dict = self.defaults_to_dict()

    def defaults_to_dict(self) -> dict:
        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
        result = {}
        for f in fields(self):
            field_value = getattr(self, f.name)
            # nested dataclasses (llm, agent) compute their defaults themselves
            if is_dataclass(type(field_value)):
                result[f.name] = field_value.defaults_to_dict()
            else:
                result[f.name] = get_field_info(f)
        return result


def get_field_info(field):
    """Extract information about a dataclass field: type, optional, and default.

    Args:
        field: The field to extract information from.

    Returns:
        A dict with the field's type, whether it's optional, and its default value.
    """
    field_type = field.type
    optional = False

    # for types like str | None, find the non-None type and set optional to True
    # this is useful for the frontend to know if a field is optional
    # and to show the correct type in the UI
    # Note: this only works for UnionTypes with None as one of the types
    if get_origin(field_type) is UnionType:
        types = get_args(field_type)
        non_none_arg = next((t for t in types if t is not type(None)), None)
        if non_none_arg is not None:
            field_type = non_none_arg
            optional = True

    # type name in a pretty format
    type_name = (
        field_type.__name__ if hasattr(field_type, '__name__') else str(field_type)
    )

    # default is always present
    default = field.default

    # return a schema with the useful info for the frontend
    return {'type': type_name.lower(), 'optional': optional, 'default': default}
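
# Illustrative note (not part of the original module): for a field declared as
# `base_url: str | None = None`, get_field_info returns
# {'type': 'str', 'optional': True, 'default': None}, while a plain
# `num_retries: int = 5` yields {'type': 'int', 'optional': False, 'default': 5}.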


def load_from_env(config: AppConfig, env_or_toml_dict: dict | os._Environ):
    """Sets config attributes from environment variables or a config.toml dict.

    Supports env-style variables such as LLM_BASE_URL, AGENT_MEMORY_ENABLED and others.

    Args:
        config: The AppConfig object to set attributes on.
        env_or_toml_dict: The environment variables or a config.toml dict.
    """

    def get_optional_type(union_type: UnionType) -> Any:
        """Returns the non-None type from a Union."""
        types = get_args(union_type)
        return next((t for t in types if t is not type(None)), None)

    # helper function to set attributes based on env vars
    def set_attr_from_env(sub_config: Any, prefix=''):
        """Set attributes of a config dataclass based on environment variables."""
        for field_name, field_type in sub_config.__annotations__.items():
            # compute the expected env var name from the prefix and field name
            # e.g. LLM_BASE_URL
            env_var_name = (prefix + field_name).upper()

            if is_dataclass(field_type):
                # nested dataclass
                nested_sub_config = getattr(sub_config, field_name)

                # the agent field: the env var for agent.name is just 'AGENT'
                if field_name == 'agent' and 'AGENT' in env_or_toml_dict:
                    setattr(nested_sub_config, 'name', env_or_toml_dict[env_var_name])
                set_attr_from_env(nested_sub_config, prefix=field_name + '_')
            elif env_var_name in env_or_toml_dict:
                # convert the env var to the correct type and set it
                value = env_or_toml_dict[env_var_name]
                try:
                    # if it's an optional type, get the non-None type
                    if get_origin(field_type) is UnionType:
                        field_type = get_optional_type(field_type)

                    # Attempt to cast the env var to the type hinted in the dataclass
                    if field_type is bool:
                        cast_value = str(value).lower() in ['true', '1']
                    else:
                        cast_value = field_type(value)
                    setattr(sub_config, field_name, cast_value)
                except (ValueError, TypeError):
                    logger.error(
                        f'Error setting env var {env_var_name}={value}: check that the value is of the right type'
                    )

    # Start processing from the root of the config object
    set_attr_from_env(config)
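
# Illustrative note (not part of the original module): with environment variables
# such as the following (placeholder values), load_from_env(config, os.environ)
# sets config.llm.base_url, config.agent.memory_enabled and config.max_iterations:
#
#   export LLM_BASE_URL="http://localhost:8000"
#   export AGENT_MEMORY_ENABLED="true"
#   export MAX_ITERATIONS="50"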


def load_from_toml(config: AppConfig, toml_file: str = 'config.toml'):
    """Load the config from the toml file. Supports both styles of config vars.

    Args:
        config: The AppConfig object to update attributes of.
        toml_file: Path to the toml file to read.
    """
    # try to read the config.toml file into the config object
    toml_config = {}

    try:
        with open(toml_file, 'r', encoding='utf-8') as toml_contents:
            toml_config = toml.load(toml_contents)
    except FileNotFoundError:
        # the file is optional, we don't need to do anything
        return
    except toml.TomlDecodeError:
        logger.warning(
            'Cannot parse config from toml, toml values have not been applied.',
            exc_info=False,
        )
        return

    # if 'core' is not in the toml, assume the old-style (env-var-like) config
    if 'core' not in toml_config:
        # re-use the env loader to set the config from env-style vars
        load_from_env(config, toml_config)
        return

    core_config = toml_config['core']

    try:
        # set llm config from the toml file
        llm_config = config.llm
        if 'llm' in toml_config:
            llm_config = LLMConfig(**toml_config['llm'])

        # set agent config from the toml file
        agent_config = config.agent
        if 'agent' in toml_config:
            agent_config = AgentConfig(**toml_config['agent'])

        # update the config object with the new values
        config = AppConfig(llm=llm_config, agent=agent_config, **core_config)
    except (TypeError, KeyError):
        logger.warning(
            'Cannot parse config from toml, toml values have not been applied.',
            exc_info=False,
        )
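
# Illustrative note (not part of the original module): a new-style config.toml has
# a [core] table plus optional [llm] and [agent] tables whose keys mirror the
# dataclass fields above (values here are placeholders):
#
#   [core]
#   workspace_base = "./workspace"
#   max_iterations = 100
#
#   [llm]
#   model = "gpt-4"
#   api_key = "..."
#
#   [agent]
#   name = "CodeActAgent"
#
# An old-style file (no [core] table) with flat env-style keys such as
# LLM_MODEL = "..." is handled by falling back to load_from_env above.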


def finalize_config(config: AppConfig):
    """More tweaks to the config after it's been loaded."""
    # In local mode there is no sandbox; the workspace has the same path as on the host
    if config.sandbox_type == 'local':
        config.workspace_mount_path_in_sandbox = config.workspace_mount_path

    if config.workspace_mount_rewrite:  # and not config.workspace_mount_path:
        # TODO why do we need to check if workspace_mount_path is None?
        base = config.workspace_base or os.getcwd()
        parts = config.workspace_mount_rewrite.split(':')
        config.workspace_mount_path = base.replace(parts[0], parts[1])

    if config.llm.embedding_base_url is None:
        config.llm.embedding_base_url = config.llm.base_url

    if config.use_host_network and platform.system() == 'Darwin':
        logger.warning(
            'Please upgrade to Docker Desktop 4.29.0 or later to use host network mode on macOS. '
            'See https://github.com/docker/roadmap/issues/238#issuecomment-2044688144 for more information.'
        )

    # make sure cache dir exists
    if config.cache_dir:
        pathlib.Path(config.cache_dir).mkdir(parents=True, exist_ok=True)
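
# Illustrative note (not part of the original module): workspace_mount_rewrite is
# an "old:new" pair. For example (placeholder paths), with
# workspace_base = "/home/user/project" and workspace_mount_rewrite = "/home/user:/data",
# finalize_config rewrites workspace_mount_path to "/data/project".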


config = AppConfig()
load_from_toml(config)
load_from_env(config, os.environ)
finalize_config(config)


# Utility function for the command line --llm-config argument
def get_llm_config_arg(llm_config_arg: str):
    """Get a group of llm settings from the config file."""
    # keep only the name, just in case
    llm_config_arg = llm_config_arg.strip('[]')
    logger.info(f'Loading llm config from {llm_config_arg}')

    # load the toml file
    try:
        with open('config.toml', 'r', encoding='utf-8') as toml_file:
            toml_config = toml.load(toml_file)
    except FileNotFoundError:
        return None
    except toml.TomlDecodeError as e:
        logger.error(f'Cannot parse llm group from {llm_config_arg}. Exception: {e}')
        return None

    # update the llm config with the specified section
    if llm_config_arg in toml_config:
        return LLMConfig(**toml_config[llm_config_arg])
    logger.debug(f'Loading from toml failed for {llm_config_arg}')
    return None
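
# Illustrative note (not part of the original module): given a config.toml section
# such as the following (section name and values are placeholders),
#
#   [llama3]
#   model = "ollama/llama3"
#   base_url = "http://localhost:11434"
#
# get_llm_config_arg('llama3') (or '[llama3]') returns an LLMConfig built from
# those values; it returns None if the section or the file is missing.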


# Command line arguments
def get_parser():
    """Get the parser for the command line arguments."""
    parser = argparse.ArgumentParser(description='Run an agent with a specific task')
    parser.add_argument(
        '-d',
        '--directory',
        type=str,
        help='The working directory for the agent',
    )
    parser.add_argument(
        '-t', '--task', type=str, default='', help='The task for the agent to perform'
    )
    parser.add_argument(
        '-f',
        '--file',
        type=str,
        help='Path to a file containing the task. Overrides -t if both are provided.',
    )
    parser.add_argument(
        '-c',
        '--agent-cls',
        default=config.agent.name,
        type=str,
        help='The agent class to use',
    )
    parser.add_argument(
        '-m',
        '--model-name',
        default=config.llm.model,
        type=str,
        help='The (litellm) model name to use',
    )
    parser.add_argument(
        '-i',
        '--max-iterations',
        default=config.max_iterations,
        type=int,
        help='The maximum number of iterations to run the agent',
    )
    parser.add_argument(
        '-n',
        '--max-chars',
        default=config.llm.max_chars,
        type=int,
        help='The maximum number of characters to send to and receive from LLM per task',
    )
    parser.add_argument(
        '-l',
        '--llm-config',
        default=None,
        type=str,
        help='The group of llm settings, e.g. a [llama3] section in the toml file. Overrides model if both are provided.',
    )
    return parser


def parse_arguments():
    """Parse the command line arguments."""
    parser = get_parser()
    args, _ = parser.parse_known_args()
    if args.directory:
        config.workspace_base = os.path.abspath(args.directory)
        print(f'Setting workspace base to {config.workspace_base}')
    return args


args = parse_arguments()
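
# Illustrative note (not part of the original module): a typical invocation of a
# script built on this config module (script name and task are placeholders):
#
#   python main.py -d ./workspace -t "write a hello world script" -c CodeActAgent -i 50
#
# -d updates config.workspace_base; the remaining flags are read from `args` by
# the caller.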