llm_config.py

from dataclasses import dataclass, fields

from openhands.core.config.config_utils import get_field_info

LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key']

@dataclass
class LLMConfig:
    """Configuration for the LLM model.

    Attributes:
        model: The model to use.
        api_key: The API key to use.
        base_url: The base URL for the API. This is necessary for local LLMs and is also used for Azure embeddings.
        api_version: The version of the API.
        embedding_model: The embedding model to use.
        embedding_base_url: The base URL for the embedding API.
        embedding_deployment_name: The name of the deployment for the embedding API. This is used for Azure OpenAI.
        aws_access_key_id: The AWS access key ID.
        aws_secret_access_key: The AWS secret access key.
        aws_region_name: The AWS region name.
        openrouter_site_url: The site URL sent to OpenRouter.
        openrouter_app_name: The application name sent to OpenRouter.
        num_retries: The number of retries to attempt.
        retry_multiplier: The multiplier for the exponential backoff.
        retry_min_wait: The minimum time to wait between retries, in seconds. This is the exponential backoff minimum. For models with very low rate limits, this can be set to 15-20.
        retry_max_wait: The maximum time to wait between retries, in seconds. This is the exponential backoff maximum.
        timeout: The timeout for the API.
        max_message_chars: The approximate maximum number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated.
        temperature: The temperature for the API.
        top_p: The top p for the API.
        custom_llm_provider: The custom LLM provider to use. This is undocumented in OpenHands and normally not used; it is documented on the litellm side.
        max_input_tokens: The maximum number of input tokens. Note that this is currently unused; the value at runtime is actually the total token limit in OpenAI (e.g. 128,000 tokens for GPT-4).
        max_output_tokens: The maximum number of output tokens. This is sent to the LLM.
        input_cost_per_token: The cost per input token. This will be available in logs for the user to check.
        output_cost_per_token: The cost per output token. This will be available in logs for the user to check.
        ollama_base_url: The base URL for the Ollama API.
        drop_params: Drop any unmapped (unsupported) params without raising an exception.
        disable_vision: If the model is vision capable, this option allows image processing to be disabled (useful for cost reduction).
        caching_prompt: Whether to use the prompt caching feature provided by the LLM.
        log_completions: Whether to log LLM completions to the state.
    """

    model: str = 'gpt-4o'
    api_key: str | None = None
    base_url: str | None = None
    api_version: str | None = None
    embedding_model: str = 'local'
    embedding_base_url: str | None = None
    embedding_deployment_name: str | None = None
    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    aws_region_name: str | None = None
    openrouter_site_url: str = 'https://docs.all-hands.dev/'
    openrouter_app_name: str = 'OpenHands'
    num_retries: int = 8
    retry_multiplier: float = 2.0
    retry_min_wait: int = 15
    retry_max_wait: int = 120
    timeout: int | None = None
    max_message_chars: int = 10_000  # maximum number of characters in an observation's content when sent to the LLM
    temperature: float = 0.0
    top_p: float = 1.0
    custom_llm_provider: str | None = None
    max_input_tokens: int | None = None
    max_output_tokens: int | None = None
    input_cost_per_token: float | None = None
    output_cost_per_token: float | None = None
    ollama_base_url: str | None = None
    drop_params: bool = True
    disable_vision: bool | None = None
    caching_prompt: bool = False
    log_completions: bool = False

    def defaults_to_dict(self) -> dict:
        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether the field is optional."""
        result = {}
        for f in fields(self):
            result[f.name] = get_field_info(f)
        return result
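
    # Usage sketch (hedged): the per-field payload comes from get_field_info in
    # openhands.core.config.config_utils, so its exact shape is defined there;
    # a call like LLMConfig().defaults_to_dict() maps each field name to that
    # metadata, e.g. {'model': <field info>, 'api_key': <field info>, ...}.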

    def __str__(self):
        attr_str = []
        for f in fields(self):
            attr_name = f.name
            attr_value = getattr(self, f.name)

            if attr_name in LLM_SENSITIVE_FIELDS:
                attr_value = '******' if attr_value else None

            attr_str.append(f'{attr_name}={repr(attr_value)}')

        return f"LLMConfig({', '.join(attr_str)})"

    def __repr__(self):
        return self.__str__()
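
    # Illustrative masking behavior (example values, not from the source):
    #   >>> str(LLMConfig(api_key='sk-secret'))
    #   "LLMConfig(model='gpt-4o', api_key='******', base_url=None, ...)"
    # Fields named in LLM_SENSITIVE_FIELDS are never rendered verbatim.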

    def to_safe_dict(self):
        """Return a dict with the sensitive fields replaced with ******."""
        ret = self.__dict__.copy()
        for k, v in ret.items():
            if k in LLM_SENSITIVE_FIELDS:
                ret[k] = '******' if v else None
        return ret
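
    # Note: to_safe_dict() masks the *current* values on a shallow copy of
    # self.__dict__, whereas defaults_to_dict() describes the dataclass fields
    # themselves; mutating the returned dict does not modify the config.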

    def set_missing_attributes(self):
        """Set any missing attributes to their default values."""
        for field_name, field_obj in self.__dataclass_fields__.items():
            if not hasattr(self, field_name):
                setattr(self, field_name, field_obj.default)
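

# A minimal usage sketch (not part of the original module; example values are
# hypothetical): build a config and confirm that sensitive fields are masked
# in both the printed form and the safe dict.
if __name__ == '__main__':
    config = LLMConfig(model='gpt-4o', api_key='sk-example', retry_min_wait=20)
    print(config)  # api_key is rendered as '******'
    safe = config.to_safe_dict()
    assert safe['api_key'] == '******'
    print(safe['model'], safe['retry_min_wait'])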