llm_config.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. from dataclasses import dataclass, fields
  3. from typing import Optional
  4. from openhands.core.config.config_utils import get_field_info
  5. from openhands.core.logger import LOG_DIR
  6. LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key']
  7. @dataclass
  8. class LLMConfig:
  9. """Configuration for the LLM model.
  10. Attributes:
  11. model: The model to use.
  12. api_key: The API key to use.
  13. base_url: The base URL for the API. This is necessary for local LLMs. It is also used for Azure embeddings.
  14. api_version: The version of the API.
  15. embedding_model: The embedding model to use.
  16. embedding_base_url: The base URL for the embedding API.
  17. embedding_deployment_name: The name of the deployment for the embedding API. This is used for Azure OpenAI.
  18. aws_access_key_id: The AWS access key ID.
  19. aws_secret_access_key: The AWS secret access key.
  20. aws_region_name: The AWS region name.
  21. num_retries: The number of retries to attempt.
  22. retry_multiplier: The multiplier for the exponential backoff.
  23. retry_min_wait: The minimum time to wait between retries, in seconds. This is exponential backoff minimum. For models with very low limits, this can be set to 15-20.
  24. retry_max_wait: The maximum time to wait between retries, in seconds. This is exponential backoff maximum.
  25. timeout: The timeout for the API.
  26. max_message_chars: The approximate max number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated.
  27. temperature: The temperature for the API.
  28. top_p: The top p for the API.
  29. custom_llm_provider: The custom LLM provider to use. This is undocumented in openhands, and normally not used. It is documented on the litellm side.
  30. max_input_tokens: The maximum number of input tokens. Note that this is currently unused, and the value at runtime is actually the total tokens in OpenAI (e.g. 128,000 tokens for GPT-4).
  31. max_output_tokens: The maximum number of output tokens. This is sent to the LLM.
  32. input_cost_per_token: The cost per input token. This will available in logs for the user to check.
  33. output_cost_per_token: The cost per output token. This will available in logs for the user to check.
  34. ollama_base_url: The base URL for the OLLAMA API.
  35. drop_params: Drop any unmapped (unsupported) params without causing an exception.
  36. modify_params: Modify params allows litellm to do transformations like adding a default message, when a message is empty.
  37. disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
  38. caching_prompt: Use the prompt caching feature if provided by the LLM and supported by the provider.
  39. log_completions: Whether to log LLM completions to the state.
  40. log_completions_folder: The folder to log LLM completions to. Required if log_completions is True.
  41. draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985).
  42. custom_tokenizer: A custom tokenizer to use for token counting.
  43. """
  44. model: str = 'claude-3-5-sonnet-20241022'
  45. api_key: str | None = None
  46. base_url: str | None = None
  47. api_version: str | None = None
  48. embedding_model: str = 'local'
  49. embedding_base_url: str | None = None
  50. embedding_deployment_name: str | None = None
  51. aws_access_key_id: str | None = None
  52. aws_secret_access_key: str | None = None
  53. aws_region_name: str | None = None
  54. openrouter_site_url: str = 'https://docs.all-hands.dev/'
  55. openrouter_app_name: str = 'OpenHands'
  56. num_retries: int = 8
  57. retry_multiplier: float = 2
  58. retry_min_wait: int = 15
  59. retry_max_wait: int = 120
  60. timeout: int | None = None
  61. max_message_chars: int = 30_000 # maximum number of characters in an observation's content when sent to the llm
  62. temperature: float = 0.0
  63. top_p: float = 1.0
  64. custom_llm_provider: str | None = None
  65. max_input_tokens: int | None = None
  66. max_output_tokens: int | None = None
  67. input_cost_per_token: float | None = None
  68. output_cost_per_token: float | None = None
  69. ollama_base_url: str | None = None
  70. # This setting can be sent in each call to litellm
  71. drop_params: bool = True
  72. # Note: this setting is actually global, unlike drop_params
  73. modify_params: bool = True
  74. disable_vision: bool | None = None
  75. caching_prompt: bool = True
  76. log_completions: bool = False
  77. log_completions_folder: str = os.path.join(LOG_DIR, 'completions')
  78. draft_editor: Optional['LLMConfig'] = None
  79. custom_tokenizer: str | None = None
  80. def defaults_to_dict(self) -> dict:
  81. """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
  82. result = {}
  83. for f in fields(self):
  84. result[f.name] = get_field_info(f)
  85. return result
  86. def __post_init__(self):
  87. """
  88. Post-initialization hook to assign OpenRouter-related variables to environment variables.
  89. This ensures that these values are accessible to litellm at runtime.
  90. """
  91. # Assign OpenRouter-specific variables to environment variables
  92. if self.openrouter_site_url:
  93. os.environ['OR_SITE_URL'] = self.openrouter_site_url
  94. if self.openrouter_app_name:
  95. os.environ['OR_APP_NAME'] = self.openrouter_app_name
  96. def __str__(self):
  97. attr_str = []
  98. for f in fields(self):
  99. attr_name = f.name
  100. attr_value = getattr(self, f.name)
  101. if attr_name in LLM_SENSITIVE_FIELDS:
  102. attr_value = '******' if attr_value else None
  103. attr_str.append(f'{attr_name}={repr(attr_value)}')
  104. return f"LLMConfig({', '.join(attr_str)})"
  105. def __repr__(self):
  106. return self.__str__()
  107. def to_safe_dict(self):
  108. """Return a dict with the sensitive fields replaced with ******."""
  109. ret = self.__dict__.copy()
  110. for k, v in ret.items():
  111. if k in LLM_SENSITIVE_FIELDS:
  112. ret[k] = '******' if v else None
  113. elif isinstance(v, LLMConfig):
  114. ret[k] = v.to_safe_dict()
  115. return ret
  116. @classmethod
  117. def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
  118. """Create an LLMConfig object from a dictionary.
  119. This function is used to create an LLMConfig object from a dictionary,
  120. with the exception of the 'draft_editor' key, which is a nested LLMConfig object.
  121. """
  122. args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)}
  123. if 'draft_editor' in llm_config_dict:
  124. draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
  125. args['draft_editor'] = draft_editor_config
  126. return cls(**args)