import warnings
from functools import partial

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from opendevin.core.config import config
from opendevin.core.logger import llm_prompt_logger, llm_response_logger
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.metrics import Metrics

__all__ = ['LLM']

message_separator = '\n\n----------\n\n'


class LLM:
    """
    The LLM class represents a Language Model instance.

    Attributes:
        model_name (str): The name of the language model.
        api_key (str): The API key for accessing the language model.
        base_url (str): The base URL for the language model API.
        api_version (str): The version of the API to use.
        max_input_tokens (int): The maximum number of tokens to send to the LLM per task.
        max_output_tokens (int): The maximum number of tokens to receive from the LLM per task.
        llm_timeout (int): The maximum time to wait for a response in seconds.
        custom_llm_provider (str): A custom LLM provider.
    """
    def __init__(
        self,
        model=None,
        api_key=None,
        base_url=None,
        api_version=None,
        num_retries=None,
        retry_min_wait=None,
        retry_max_wait=None,
        llm_timeout=None,
        llm_temperature=None,
        llm_top_p=None,
        custom_llm_provider=None,
        max_input_tokens=None,
        max_output_tokens=None,
        llm_config=None,
        metrics=None,
        cost_metric_supported=True,
    ):
        """
        Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            model (str, optional): The name of the language model. Defaults to LLM_MODEL.
            api_key (str, optional): The API key for accessing the language model. Defaults to LLM_API_KEY.
            base_url (str, optional): The base URL for the language model API. Defaults to LLM_BASE_URL. Not necessary for OpenAI.
            api_version (str, optional): The version of the API to use. Defaults to LLM_API_VERSION. Not necessary for OpenAI.
            num_retries (int, optional): The number of retries for API calls. Defaults to LLM_NUM_RETRIES.
            retry_min_wait (int, optional): The minimum time to wait between retries in seconds. Defaults to LLM_RETRY_MIN_TIME.
            retry_max_wait (int, optional): The maximum time to wait between retries in seconds. Defaults to LLM_RETRY_MAX_TIME.
            max_input_tokens (int, optional): The maximum number of tokens to send to the LLM per task. Defaults to LLM_MAX_INPUT_TOKENS.
            max_output_tokens (int, optional): The maximum number of tokens to receive from the LLM per task. Defaults to LLM_MAX_OUTPUT_TOKENS.
            custom_llm_provider (str, optional): A custom LLM provider. Defaults to LLM_CUSTOM_LLM_PROVIDER.
            llm_timeout (int, optional): The maximum time to wait for a response in seconds. Defaults to LLM_TIMEOUT.
            llm_temperature (float, optional): The temperature for LLM sampling. Defaults to LLM_TEMPERATURE.
            llm_top_p (float, optional): The top_p value for LLM sampling. Defaults to the value in the LLM config.
            llm_config (LLMConfig, optional): An LLM configuration whose values serve as fallbacks. Defaults to the global config.llm.
            metrics (Metrics, optional): The metrics object to use. Defaults to None.
            cost_metric_supported (bool, optional): Whether the cost metric is supported. Defaults to True.
        """
        if llm_config is None:
            llm_config = config.llm
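        # Explicit constructor arguments take precedence; anything left unset
        # falls back to the corresponding value in the LLM config.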
        model = model if model is not None else llm_config.model
        api_key = api_key if api_key is not None else llm_config.api_key
        base_url = base_url if base_url is not None else llm_config.base_url
        api_version = api_version if api_version is not None else llm_config.api_version
        num_retries = num_retries if num_retries is not None else llm_config.num_retries
        retry_min_wait = (
            retry_min_wait if retry_min_wait is not None else llm_config.retry_min_wait
        )
        retry_max_wait = (
            retry_max_wait if retry_max_wait is not None else llm_config.retry_max_wait
        )
        llm_timeout = llm_timeout if llm_timeout is not None else llm_config.timeout
        llm_temperature = (
            llm_temperature if llm_temperature is not None else llm_config.temperature
        )
        llm_top_p = llm_top_p if llm_top_p is not None else llm_config.top_p
        custom_llm_provider = (
            custom_llm_provider
            if custom_llm_provider is not None
            else llm_config.custom_llm_provider
        )
        max_input_tokens = (
            max_input_tokens
            if max_input_tokens is not None
            else llm_config.max_input_tokens
        )
        max_output_tokens = (
            max_output_tokens
            if max_output_tokens is not None
            else llm_config.max_output_tokens
        )
        metrics = metrics if metrics is not None else Metrics()
        logger.info(f'Initializing LLM with model: {model}')
        self.model_name = model
        self.api_key = api_key
        self.base_url = base_url
        self.api_version = api_version
        self.max_input_tokens = max_input_tokens
        self.max_output_tokens = max_output_tokens
        self.llm_timeout = llm_timeout
        self.custom_llm_provider = custom_llm_provider
        self.metrics = metrics
        self.cost_metric_supported = cost_metric_supported

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            if not self.model_name.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.model_name.split(':')[0])
            else:
                self.model_info = litellm.get_model_info(self.model_name)
        # noinspection PyBroadException
        except Exception:
            logger.warning(f'Could not get model info for {self.model_name}')

        if self.max_input_tokens is None:
            if self.model_info is not None and 'max_input_tokens' in self.model_info:
                self.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                self.max_input_tokens = 4096

        if self.max_output_tokens is None:
            if self.model_info is not None and 'max_output_tokens' in self.model_info:
                self.max_output_tokens = self.model_info['max_output_tokens']
            else:
                # Enough tokens for most output actions, and not too many for a bad llm to get carried away responding
                # with thousands of unwanted tokens
                self.max_output_tokens = 1024
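        # Bind the resolved model, credentials, and sampling settings into a
        # single completion callable; the wrapper below adds retries and
        # prompt/response logging on top of it.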
        self._completion = partial(
            litellm_completion,
            model=self.model_name,
            api_key=self.api_key,
            base_url=self.base_url,
            api_version=self.api_version,
            custom_llm_provider=custom_llm_provider,
            max_tokens=self.max_output_tokens,
            timeout=self.llm_timeout,
            temperature=llm_temperature,
            top_p=llm_top_p,
        )

        completion_unwrapped = self._completion

        def attempt_on_error(retry_state):
            logger.error(
                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
                exc_info=False,
            )
            return True
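        # Retry transient provider failures (rate limits, connection errors,
        # service unavailability) with randomized exponential backoff.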
        @retry(
            reraise=True,
            stop=stop_after_attempt(num_retries),
            wait=wait_random_exponential(min=retry_min_wait, max=retry_max_wait),
            retry=retry_if_exception_type(
                (RateLimitError, APIConnectionError, ServiceUnavailableError)
            ),
            after=attempt_on_error,
        )
        def wrapper(*args, **kwargs):
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]
            debug_message = ''
            for message in messages:
                debug_message += message_separator + message['content']
            llm_prompt_logger.debug(debug_message)
            resp = completion_unwrapped(*args, **kwargs)
            message_back = resp['choices'][0]['message']['content']
            llm_response_logger.debug(message_back)
            return resp

        self._completion = wrapper  # type: ignore
    @property
    def completion(self):
        """
        The litellm completion function, wrapped with retries and prompt/response logging.
        """
        return self._completion

    def do_completion(self, *args, **kwargs):
        """
        Wrapper for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        resp = self._completion(*args, **kwargs)
        self.post_completion(resp)
        return resp
    def post_completion(self, response) -> None:
        """
        Post-process the completion response: compute its cost and log it.
        """
        try:
            cur_cost = self.completion_cost(response)
        except Exception:
            cur_cost = 0
        if self.cost_metric_supported:
            logger.info(
                'Cost: %.2f USD | Accumulated Cost: %.2f USD',
                cur_cost,
                self.metrics.accumulated_cost,
            )
    def get_token_count(self, messages):
        """
        Get the number of tokens in a list of messages.

        Args:
            messages (list): A list of messages.

        Returns:
            int: The number of tokens.
        """
        return litellm.token_counter(model=self.model_name, messages=messages)
    def is_local(self):
        """
        Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.base_url:
                    return True
        elif self.model_name is not None:
            if self.model_name.startswith('ollama'):
                return True
        return False
    def completion_cost(self, response):
        """
        Calculate the cost of a completion response based on the model. Local models are treated as free.
        The current cost is added to the accumulated cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            config.llm.input_cost_per_token is not None
            and config.llm.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=config.llm.input_cost_per_token,
                output_cost_per_token=config.llm.output_cost_per_token,
            )
            logger.info(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token
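        # Local models cost nothing; for remote models, defer to litellm's
        # pricing table (or the custom per-token costs set above).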
        if not self.is_local():
            try:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
                self.metrics.add_cost(cost)
                return cost
            except Exception:
                self.cost_metric_supported = False
                logger.warning('Cost calculation not supported for this model.')
        return 0.0
    def __str__(self):
        if self.api_version:
            return f'LLM(model={self.model_name}, api_version={self.api_version}, base_url={self.base_url})'
        elif self.base_url:
            return f'LLM(model={self.model_name}, base_url={self.base_url})'
        return f'LLM(model={self.model_name})'

    def __repr__(self):
        return str(self)
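

# Minimal usage sketch (not part of the module API): the model name and prompt
# below are examples only, and a valid API key must be available through the
# OpenDevin config or environment for the call to succeed.
if __name__ == '__main__':
    llm = LLM(model='gpt-3.5-turbo')
    print(llm)
    response = llm.do_completion(
        messages=[{'role': 'user', 'content': 'Say hello in one short sentence.'}]
    )
    print(response['choices'][0]['message']['content'])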