# llm.py
  1. import asyncio
  2. import copy
  3. import warnings
  4. from functools import partial
  5. from openhands.core.config import LLMConfig
  6. with warnings.catch_warnings():
  7. warnings.simplefilter('ignore')
  8. import litellm
  9. from litellm import completion as litellm_completion
  10. from litellm import completion_cost as litellm_completion_cost
  11. from litellm.exceptions import (
  12. APIConnectionError,
  13. ContentPolicyViolationError,
  14. InternalServerError,
  15. OpenAIError,
  16. RateLimitError,
  17. ServiceUnavailableError,
  18. )
  19. from litellm.types.utils import CostPerToken
  20. from tenacity import (
  21. retry,
  22. retry_if_exception_type,
  23. stop_after_attempt,
  24. wait_random_exponential,
  25. )
  26. from openhands.core.exceptions import UserCancelledError
  27. from openhands.core.logger import llm_prompt_logger, llm_response_logger
  28. from openhands.core.logger import openhands_logger as logger
  29. from openhands.core.metrics import Metrics
__all__ = ['LLM']

# Separator inserted between message chunks when building the prompt/response
# debug log text (see the completion wrappers in LLM.__init__).
message_separator = '\n\n----------\n\n'

# Models for which prompt caching is enabled (checked in LLM.__init__ to set
# supports_prompt_caching).
cache_prompting_supported_models = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]
class LLM:
    """The LLM class represents a Language Model instance.

    Wraps litellm's sync/async completion calls with retry handling,
    prompt/response logging, and cost/metrics accounting.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """
    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: Optional shared Metrics accumulator; a fresh one is
                created when omitted.
        """
        # Deep-copy so the token-limit defaulting below never mutates the
        # caller's config object.
        self.config = copy.deepcopy(config)
        self.metrics = metrics if metrics is not None else Metrics()
        # Flipped to False by completion_cost() when litellm cannot price
        # this model.
        self.cost_metric_supported = True
        self.supports_prompt_caching = (
            self.config.model in cache_prompting_supported_models
        )
        # Set up config attributes with default values to prevent AttributeError
        LLMConfig.set_missing_attributes(self.config)

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                # Strip any ':tag' suffix before asking litellm for metadata.
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            if (
                self.model_info is not None
                and 'max_output_tokens' in self.model_info
                and isinstance(self.model_info['max_output_tokens'], int)
            ):
                self.config.max_output_tokens = self.model_info['max_output_tokens']
            else:
                # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                self.config.max_output_tokens = 1024

        if self.config.drop_params:
            litellm.drop_params = self.config.drop_params

        # Bind connection/sampling parameters once; the wrappers below layer
        # retries plus prompt/response logging on top of this partial.
        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        completion_unwrapped = self._completion

        def attempt_on_error(retry_state):
            # tenacity "after" hook: log every failed attempt before retrying.
            logger.error(
                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
                exc_info=False,
            )
            return None

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    RateLimitError,
                    APIConnectionError,
                    ServiceUnavailableError,
                    InternalServerError,
                    ContentPolicyViolationError,
                )
            ),
            after=attempt_on_error,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                content = message['content']
                if isinstance(content, list):
                    # Multimodal content: log each text/image element separately.
                    for element in content:
                        if isinstance(element, dict):
                            if 'text' in element:
                                content_str = element['text'].strip()
                            elif (
                                'image_url' in element and 'url' in element['image_url']
                            ):
                                content_str = element['image_url']['url']
                            else:
                                content_str = str(element)
                        else:
                            content_str = str(element)
                        debug_message += message_separator + content_str
                else:
                    content_str = str(content)
                    debug_message += message_separator + content_str
            llm_prompt_logger.debug(debug_message)

            # skip if messages is empty (thus debug_message is empty)
            if debug_message:
                resp = completion_unwrapped(*args, **kwargs)
            else:
                # Synthesize an empty response so callers always get the
                # same dict shape back.
                resp = {'choices': [{'message': {'content': ''}}]}

            # log the response
            message_back = resp['choices'][0]['message']['content']
            llm_response_logger.debug(message_back)

            # post-process to log costs
            self._post_completion(resp)
            return resp

        self._completion = wrapper  # type: ignore

        # Async version
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=True,
        )

        async_completion_unwrapped = self._async_completion

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    RateLimitError,
                    APIConnectionError,
                    ServiceUnavailableError,
                    InternalServerError,
                    ContentPolicyViolationError,
                )
            ),
            after=attempt_on_error,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                content = message['content']
                if isinstance(content, list):
                    for element in content:
                        if isinstance(element, dict):
                            if 'text' in element:
                                # NOTE(review): unlike the sync wrapper, text
                                # is not .strip()ped here — confirm intended.
                                content_str = element['text']
                            elif (
                                'image_url' in element and 'url' in element['image_url']
                            ):
                                content_str = element['image_url']['url']
                            else:
                                content_str = str(element)
                        else:
                            content_str = str(element)
                        debug_message += message_separator + content_str
                else:
                    content_str = str(content)
                    debug_message += message_separator + content_str
            llm_prompt_logger.debug(debug_message)

            async def check_stopped():
                # Poll the optional cancellation callback so a user can abort
                # an in-flight request.
                while True:
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError('LLM request cancelled by user')
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # skip if messages is empty (thus debug_message is empty)
                if debug_message:
                    message_back = resp['choices'][0]['message']['content']
                    llm_response_logger.debug(message_back)
                else:
                    resp = {'choices': [{'message': {'content': ''}}]}
                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except OpenAIError as e:
                logger.error(f'OpenAIError occurred:\n{e}')
                raise
            except (
                RateLimitError,
                APIConnectionError,
                ServiceUnavailableError,
                InternalServerError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                # Always tear down the cancellation poller.
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    RateLimitError,
                    APIConnectionError,
                    ServiceUnavailableError,
                    InternalServerError,
                    ContentPolicyViolationError,
                )
            ),
            after=attempt_on_error,
        )
        async def async_acompletion_stream_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion with streaming function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                debug_message += message_separator + message['content']
            llm_prompt_logger.debug(debug_message)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )
                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta']['content']
                    llm_response_logger.debug(message_back)
                    self._post_completion(chunk)
                    yield chunk
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except OpenAIError as e:
                logger.error(f'OpenAIError occurred:\n{e}')
                raise
            except (
                RateLimitError,
                APIConnectionError,
                ServiceUnavailableError,
                InternalServerError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_completion = async_completion_wrapper  # type: ignore
        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
  348. async def _call_acompletion(self, *args, **kwargs):
  349. return await litellm.acompletion(*args, **kwargs)
  350. @property
  351. def completion(self):
  352. """Decorator for the litellm completion function.
  353. Check the complete documentation at https://litellm.vercel.app/docs/completion
  354. """
  355. return self._completion
  356. @property
  357. def async_completion(self):
  358. """Decorator for the async litellm acompletion function.
  359. Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
  360. """
  361. return self._async_completion
  362. @property
  363. def async_streaming_completion(self):
  364. """Decorator for the async litellm acompletion function with streaming.
  365. Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
  366. """
  367. return self._async_streaming_completion
  368. def supports_vision(self):
  369. return litellm.supports_vision(self.config.model)
  370. def _post_completion(self, response) -> None:
  371. """Post-process the completion response."""
  372. try:
  373. cur_cost = self.completion_cost(response)
  374. except Exception:
  375. cur_cost = 0
  376. stats = ''
  377. if self.cost_metric_supported:
  378. stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
  379. cur_cost,
  380. self.metrics.accumulated_cost,
  381. )
  382. usage = response.get('usage')
  383. if usage:
  384. input_tokens = usage.get('prompt_tokens')
  385. output_tokens = usage.get('completion_tokens')
  386. if input_tokens:
  387. stats += 'Input tokens: ' + str(input_tokens) + '\n'
  388. if output_tokens:
  389. stats += 'Output tokens: ' + str(output_tokens) + '\n'
  390. model_extra = usage.get('model_extra', {})
  391. cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
  392. if cache_creation_input_tokens:
  393. stats += (
  394. 'Input tokens (cache write): '
  395. + str(cache_creation_input_tokens)
  396. + '\n'
  397. )
  398. cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
  399. if cache_read_input_tokens:
  400. stats += (
  401. 'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
  402. )
  403. if stats:
  404. logger.info(stats)
  405. def get_token_count(self, messages):
  406. """Get the number of tokens in a list of messages.
  407. Args:
  408. messages (list): A list of messages.
  409. Returns:
  410. int: The number of tokens.
  411. """
  412. return litellm.token_counter(model=self.config.model, messages=messages)
  413. def is_local(self):
  414. """Determines if the system is using a locally running LLM.
  415. Returns:
  416. boolean: True if executing a local model.
  417. """
  418. if self.config.base_url is not None:
  419. for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
  420. if substring in self.config.base_url:
  421. return True
  422. elif self.config.model is not None:
  423. if self.config.model.startswith('ollama'):
  424. return True
  425. return False
  426. def completion_cost(self, response):
  427. """Calculate the cost of a completion response based on the model. Local models are treated as free.
  428. Add the current cost into total cost in metrics.
  429. Args:
  430. response: A response from a model invocation.
  431. Returns:
  432. number: The cost of the response.
  433. """
  434. if not self.cost_metric_supported:
  435. return 0.0
  436. extra_kwargs = {}
  437. if (
  438. self.config.input_cost_per_token is not None
  439. and self.config.output_cost_per_token is not None
  440. ):
  441. cost_per_token = CostPerToken(
  442. input_cost_per_token=self.config.input_cost_per_token,
  443. output_cost_per_token=self.config.output_cost_per_token,
  444. )
  445. logger.info(f'Using custom cost per token: {cost_per_token}')
  446. extra_kwargs['custom_cost_per_token'] = cost_per_token
  447. if not self.is_local():
  448. try:
  449. cost = litellm_completion_cost(
  450. completion_response=response, **extra_kwargs
  451. )
  452. self.metrics.add_cost(cost)
  453. return cost
  454. except Exception:
  455. self.cost_metric_supported = False
  456. logger.warning('Cost calculation not supported for this model.')
  457. return 0.0
  458. def __str__(self):
  459. if self.config.api_version:
  460. return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
  461. elif self.config.base_url:
  462. return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
  463. return f'LLM(model={self.config.model})'
  464. def __repr__(self):
  465. return str(self)
    def reset(self) -> None:
        """Discard accumulated metrics by replacing them with a fresh Metrics object."""
        self.metrics = Metrics()