llm.py

import copy
import os
import time
import warnings
from functools import partial
from typing import Any

import requests

from openhands.core.config import LLMConfig

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm

from litellm import Message as LiteLLMMessage
from litellm import ModelInfo, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.utils import create_pretrained_tokenizer

from openhands.core.exceptions import CloudFlareBlockageError
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.llm.debug_mixin import DebugMixin
from openhands.llm.fn_call_converter import (
    STOP_WORDS,
    convert_fncall_messages_to_non_fncall_messages,
    convert_non_fncall_messages_to_fncall_messages,
)
from openhands.llm.metrics import Metrics
from openhands.llm.retry_mixin import RetryMixin

__all__ = ['LLM']
# tuple of exceptions to retry on
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    APIConnectionError,
    # FIXME: APIError is useful on 502 from a proxy for example,
    # but it also retries on other errors that are permanent
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)
# cache prompt supporting models
# remove this when gemini and deepseek are supported
CACHE_PROMPT_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-haiku-20241022',
    'claude-3-haiku-20240307',
    'claude-3-opus-20240229',
]
# function calling supporting models
FUNCTION_CALLING_SUPPORTED_MODELS = [
    'claude-3-5-sonnet',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-sonnet-20241022',
    'claude-3.5-haiku',
    'claude-3-5-haiku-20241022',
    'gpt-4o-mini',
    'gpt-4o',
]
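
# Note (illustrative, matching the checks in is_caching_prompt_active() and
# is_function_calling_active() below): a model qualifies either by its exact
# configured name or by the suffix after the provider prefix; for example, a
# hypothetical 'anthropic/claude-3-5-sonnet-20241022' matches both lists via
# model.split('/')[-1], while 'ollama/llama3' matches neither.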


class LLM(RetryMixin, DebugMixin):
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """
    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics to use.
        """
        self._tried_model_info = False
        self.metrics: Metrics = (
            metrics if metrics is not None else Metrics(model_name=config.model)
        )
        self.cost_metric_supported: bool = True
        self.config: LLMConfig = copy.deepcopy(config)

        self.model_info: ModelInfo | None = None

        if self.config.log_completions:
            if self.config.log_completions_folder is None:
                raise RuntimeError(
                    'log_completions_folder is required when log_completions is enabled'
                )
            os.makedirs(self.config.log_completions_folder, exist_ok=True)

        # call init_model_info to initialize config.max_output_tokens
        # which is used in partial function
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.init_model_info()

        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        if self.is_caching_prompt_active():
            logger.debug('LLM: caching prompt enabled')
        if self.is_function_calling_active():
            logger.debug('LLM: model supports function calling')

        # Compatibility flag: use string serializer for DeepSeek models
        # See this issue: https://github.com/All-Hands-AI/OpenHands/issues/5818
        self._use_string_serializer = False
        if 'deepseek' in self.config.model:
            self._use_string_serializer = True

        # if using a custom tokenizer, make sure it's loaded and accessible in the format expected by litellm
        if self.config.custom_tokenizer is not None:
            self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
        else:
            self.tokenizer = None

        # set up the completion function
        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )
        self._completion_unwrapped = self._completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            from openhands.core.utils import json

            messages: list[dict[str, Any]] | dict[str, Any] = []
            mock_function_calling = kwargs.pop('mock_function_calling', False)

            # some callers might send the model and messages directly
            # litellm allows positional args, like completion(model, messages, **kwargs)
            if len(args) > 1:
                # ignore the first argument if it's provided (it would be the model)
                # design wise: we don't allow overriding the configured values
                # implementation wise: the partial function set the model as a kwarg already
                # as well as other kwargs
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                # remove the first args, they're sent in kwargs
                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # ensure we work with a list of messages
            messages = messages if isinstance(messages, list) else [messages]

            original_fncall_messages = copy.deepcopy(messages)
            mock_fncall_tools = None
            if mock_function_calling:
                assert (
                    'tools' in kwargs
                ), "'tools' must be in kwargs when mock_function_calling is True"
                messages = convert_fncall_messages_to_non_fncall_messages(
                    messages, kwargs['tools']
                )
                kwargs['messages'] = messages
                kwargs['stop'] = STOP_WORDS
                mock_fncall_tools = kwargs.pop('tools')

            # if we have no messages, something went very wrong
            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            # log the entire LLM prompt
            self.log_prompt(messages)

            if self.is_caching_prompt_active():
                # Anthropic-specific prompt caching
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            # set litellm modify_params to the configured value
            # True by default to allow litellm to do transformations like adding a default message, when a message is empty
            # NOTE: this setting is global; unlike drop_params, it cannot be overridden in the litellm completion partial
            litellm.modify_params = self.config.modify_params

            try:
                # Record start time for latency measurement
                start_time = time.time()

                # we don't support streaming here, thus we get a ModelResponse
                resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)

                # Calculate and record latency
                latency = time.time() - start_time
                response_id = resp.get('id', 'unknown')
                self.metrics.add_response_latency(latency, response_id)

                non_fncall_response = copy.deepcopy(resp)
                if mock_function_calling:
                    assert len(resp.choices) == 1
                    assert mock_fncall_tools is not None
                    non_fncall_response_message = resp.choices[0].message
                    fn_call_messages_with_response = (
                        convert_non_fncall_messages_to_fncall_messages(
                            messages + [non_fncall_response_message], mock_fncall_tools
                        )
                    )
                    fn_call_response_message = fn_call_messages_with_response[-1]
                    if not isinstance(fn_call_response_message, LiteLLMMessage):
                        fn_call_response_message = LiteLLMMessage(
                            **fn_call_response_message
                        )
                    resp.choices[0].message = fn_call_response_message

                message_back: str = resp['choices'][0]['message']['content'] or ''
                tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
                if tool_calls:
                    for tool_call in tool_calls:
                        fn_name = tool_call.function.name
                        fn_args = tool_call.function.arguments
                        message_back += f'\nFunction call: {fn_name}({fn_args})'

                # log the LLM response
                self.log_response(message_back)

                # post-process the response first to calculate cost
                cost = self._post_completion(resp)

                # log for evals or other scripts that need the raw completion
                if self.config.log_completions:
                    assert self.config.log_completions_folder is not None
                    log_file = os.path.join(
                        self.config.log_completions_folder,
                        # use the metric model name (for draft editor)
                        f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
                    )

                    # set up the dict to be logged
                    _d = {
                        'messages': messages,
                        'response': resp,
                        'args': args,
                        'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
                        'timestamp': time.time(),
                        'cost': cost,
                    }

                    # if non-native function calling, save messages/response separately
                    if mock_function_calling:
                        # Overwrite response as non-fncall to be consistent with messages
                        _d['response'] = non_fncall_response
                        # Save fncall_messages/response separately
                        _d['fncall_messages'] = original_fncall_messages
                        _d['fncall_response'] = resp
                    with open(log_file, 'w') as f:
                        f.write(json.dumps(_d))

                return resp
            except APIError as e:
                if 'Attention Required! | Cloudflare' in str(e):
                    raise CloudFlareBlockageError(
                        'Request blocked by CloudFlare'
                    ) from e
                raise

        self._completion = wrapper
    @property
    def completion(self):
        """Decorator for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion
    def init_model_info(self):
        if self._tried_model_info:
            return
        self._tried_model_info = True
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
        except Exception as e:
            logger.debug(f'Error getting model info: {e}')

        if self.config.model.startswith('litellm_proxy/'):
            # IF we are using LiteLLM proxy, get model info from LiteLLM proxy
            # GET {base_url}/v1/model/info with litellm_model_id as path param
            response = requests.get(
                f'{self.config.base_url}/v1/model/info',
                headers={'Authorization': f'Bearer {self.config.api_key}'},
            )
            resp_json = response.json()
            if 'data' not in resp_json:
                logger.error(
                    f'Error getting model info from LiteLLM proxy: {resp_json}'
                )
            all_model_info = resp_json.get('data', [])
            current_model_info = next(
                (
                    info
                    for info in all_model_info
                    if info['model_name']
                    == self.config.model.removeprefix('litellm_proxy/')
                ),
                None,
            )
            if current_model_info:
                self.model_info = current_model_info['model_info']

        # Last two attempts to get model info from NAME
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
            # noinspection PyBroadException
            except Exception:
                pass
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split('/')[-1]
                )
            # noinspection PyBroadException
            except Exception:
                pass
        logger.debug(f'Model info: {self.model_info}')

        if self.config.model.startswith('huggingface'):
            # HF doesn't support the OpenAI default value for top_p (1)
            logger.debug(
                f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
            )
            self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p
        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            # Safe default for any potentially viable model
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                # max_output_tokens has precedence over max_tokens, if either exists.
                # litellm has models with both, one or none of these 2 parameters!
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']
    def vision_is_active(self) -> bool:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            return not self.config.disable_vision and self._supports_vision()

    def _supports_vision(self) -> bool:
        """Acquire from litellm if model is vision capable.

        Returns:
            bool: True if model is vision capable. Return False if model not supported by litellm.
        """
        # litellm.supports_vision currently returns False for 'openai/gpt-...' or 'anthropic/claude-...' (with prefixes)
        # but model_info will have the correct value for some reason.
        # we can go with it, but we will need to keep an eye if model_info is correct for Vertex or other providers
        # remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
        # Check both the full model name and the name after proxy prefix for vision support
        return (
            litellm.supports_vision(self.config.model)
            or litellm.supports_vision(self.config.model.split('/')[-1])
            or (
                self.model_info is not None
                and self.model_info.get('supports_vision', False)
            )
        )
    def is_caching_prompt_active(self) -> bool:
        """Check if prompt caching is supported and enabled for the current model.

        Returns:
            boolean: True if prompt caching is supported and enabled for the given model.
        """
        return (
            self.config.caching_prompt is True
            and (
                self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
                or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
            )
            # We don't need to look up model_info, because only Anthropic models need the explicit caching breakpoint
        )
    def is_function_calling_active(self) -> bool:
        # Check if model name is in supported list before checking model_info
        model_name_supported = (
            self.config.model in FUNCTION_CALLING_SUPPORTED_MODELS
            or self.config.model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
            or any(m in self.config.model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
        )
        return model_name_supported
    def _post_completion(self, response: ModelResponse) -> float:
        """Post-process the completion response.

        Logs the cost and usage stats of the completion call.
        """
        try:
            cur_cost = self._completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            # keep track of the cost
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        # Add latency to stats if available
        if self.metrics.response_latencies:
            latest_latency = self.metrics.response_latencies[-1]
            stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency

        usage: Usage | None = response.get('usage')
        if usage:
            # keep track of the input and output tokens
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')

            if input_tokens:
                stats += 'Input tokens: ' + str(input_tokens)

            if output_tokens:
                stats += (
                    (' | ' if input_tokens else '')
                    + 'Output tokens: '
                    + str(output_tokens)
                    + '\n'
                )

            # read the prompt cache hit, if any
            prompt_tokens_details: PromptTokensDetails = usage.get(
                'prompt_tokens_details'
            )
            cache_hit_tokens = (
                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
            )
            if cache_hit_tokens:
                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

            # For Anthropic, the cache writes have a different cost than regular input tokens
            # but litellm doesn't separate them in the usage stats
            # so we can read it from the provider-specific extra field
            model_extra = usage.get('model_extra', {})
            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_write_tokens:
                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

        # log the stats
        if stats:
            logger.debug(stats)

        return cur_cost
    def get_token_count(self, messages: list[dict] | list[Message]) -> int:
        """Get the number of tokens in a list of messages. Use dicts for better token counting.

        Args:
            messages (list): A list of messages, either as a list of dicts or as a list of Message objects.

        Returns:
            int: The number of tokens.
        """
        # attempt to convert Message objects to dicts, litellm expects dicts
        if (
            isinstance(messages, list)
            and len(messages) > 0
            and isinstance(messages[0], Message)
        ):
            logger.info(
                'Message objects now include serialized tool calls in token counting'
            )
            messages = self.format_messages_for_llm(messages)  # type: ignore

        # try to get the token count with the default litellm tokenizers
        # or the custom tokenizer if set for this LLM configuration
        try:
            return litellm.token_counter(
                model=self.config.model,
                messages=messages,
                custom_tokenizer=self.tokenizer,
            )
        except Exception as e:
            # limit logspam in case token count is not supported
            logger.error(
                f'Error getting token count for\n model {self.config.model}\n{e}'
                + (
                    f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
                    if self.config.custom_tokenizer is not None
                    else ''
                )
            )
            return 0
    def _is_local(self) -> bool:
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False
    def _completion_cost(self, response) -> float:
        """Calculate the cost of a completion response based on the model. Local models are treated as free.

        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.debug(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        try:
            # try directly get response_cost from response
            cost = getattr(response, '_hidden_params', {}).get('response_cost', None)
            if cost is None:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
            self.metrics.add_cost(cost)
            return cost
        except Exception:
            self.cost_metric_supported = False
            logger.debug('Cost calculation not supported for this model.')
        return 0.0
    def __str__(self):
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self):
        return str(self)

    def reset(self) -> None:
        self.metrics.reset()
    def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
        if isinstance(messages, Message):
            messages = [messages]

        # set flags to know how to serialize the messages
        for message in messages:
            message.cache_enabled = self.is_caching_prompt_active()
            message.vision_enabled = self.vision_is_active()
            message.function_calling_enabled = self.is_function_calling_active()
            if 'deepseek' in self.config.model:
                message.force_string_serializer = True

        # let pydantic handle the serialization
        return [message.model_dump() for message in messages]
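

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the module's API: a minimal example of
# driving this class end to end. It assumes LLMConfig accepts the fields shown
# and that a valid API key replaces the placeholder value below.
if __name__ == '__main__':
    example_config = LLMConfig(model='claude-3-5-sonnet-20241022', api_key='...')
    llm = LLM(example_config)
    print(llm)  # e.g. LLM(model=claude-3-5-sonnet-20241022)

    # send a single user message through the retrying completion wrapper
    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Say hello in one word.'}]
    )
    print(response['choices'][0]['message']['content'])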