# llm.py — LLM wrapper around litellm with retry, logging and cost tracking.
  1. import asyncio
  2. import copy
  3. import warnings
  4. from functools import partial
  5. from openhands.core.config import LLMConfig
  6. with warnings.catch_warnings():
  7. warnings.simplefilter('ignore')
  8. import litellm
  9. from litellm import completion as litellm_completion
  10. from litellm import completion_cost as litellm_completion_cost
  11. from litellm.exceptions import (
  12. APIConnectionError,
  13. ContentPolicyViolationError,
  14. InternalServerError,
  15. NotFoundError,
  16. OpenAIError,
  17. RateLimitError,
  18. ServiceUnavailableError,
  19. )
  20. from litellm.types.utils import CostPerToken
  21. from tenacity import (
  22. retry,
  23. retry_if_exception_type,
  24. stop_after_attempt,
  25. wait_random_exponential,
  26. )
  27. from openhands.core.exceptions import LLMResponseError, UserCancelledError
  28. from openhands.core.logger import llm_prompt_logger, llm_response_logger
  29. from openhands.core.logger import openhands_logger as logger
  30. from openhands.core.metrics import Metrics
__all__ = ['LLM']

# Separator inserted between individual message contents when flattening a
# conversation into a single string for the prompt/response debug logs.
message_separator = '\n\n----------\n\n'

# Models known to support Anthropic-style prompt caching; checked in
# LLM.__init__ to set `supports_prompt_caching`.
cache_prompting_supported_models = [
    'claude-3-5-sonnet-20240620',
    'claude-3-haiku-20240307',
]
  37. class LLM:
  38. """The LLM class represents a Language Model instance.
  39. Attributes:
  40. config: an LLMConfig object specifying the configuration of the LLM.
  41. """
  42. def __init__(
  43. self,
  44. config: LLMConfig,
  45. metrics: Metrics | None = None,
  46. ):
  47. """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.
  48. Passing simple parameters always overrides config.
  49. Args:
  50. config: The LLM configuration
  51. """
  52. self.config = copy.deepcopy(config)
  53. self.metrics = metrics if metrics is not None else Metrics()
  54. self.cost_metric_supported = True
  55. self.supports_prompt_caching = (
  56. self.config.model in cache_prompting_supported_models
  57. )
  58. # Set up config attributes with default values to prevent AttributeError
  59. LLMConfig.set_missing_attributes(self.config)
  60. # litellm actually uses base Exception here for unknown model
  61. self.model_info = None
  62. try:
  63. if self.config.model.startswith('openrouter'):
  64. self.model_info = litellm.get_model_info(self.config.model)
  65. else:
  66. self.model_info = litellm.get_model_info(
  67. self.config.model.split(':')[0]
  68. )
  69. # noinspection PyBroadException
  70. except Exception as e:
  71. logger.warning(f'Could not get model info for {config.model}:\n{e}')
  72. # Set the max tokens in an LM-specific way if not set
  73. if self.config.max_input_tokens is None:
  74. if (
  75. self.model_info is not None
  76. and 'max_input_tokens' in self.model_info
  77. and isinstance(self.model_info['max_input_tokens'], int)
  78. ):
  79. self.config.max_input_tokens = self.model_info['max_input_tokens']
  80. else:
  81. # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
  82. self.config.max_input_tokens = 4096
  83. if self.config.max_output_tokens is None:
  84. if (
  85. self.model_info is not None
  86. and 'max_output_tokens' in self.model_info
  87. and isinstance(self.model_info['max_output_tokens'], int)
  88. ):
  89. self.config.max_output_tokens = self.model_info['max_output_tokens']
  90. else:
  91. # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
  92. self.config.max_output_tokens = 1024
  93. if self.config.drop_params:
  94. litellm.drop_params = self.config.drop_params
  95. self._completion = partial(
  96. litellm_completion,
  97. model=self.config.model,
  98. api_key=self.config.api_key,
  99. base_url=self.config.base_url,
  100. api_version=self.config.api_version,
  101. custom_llm_provider=self.config.custom_llm_provider,
  102. max_tokens=self.config.max_output_tokens,
  103. timeout=self.config.timeout,
  104. temperature=self.config.temperature,
  105. top_p=self.config.top_p,
  106. )
  107. completion_unwrapped = self._completion
  108. def attempt_on_error(retry_state):
  109. logger.error(
  110. f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
  111. exc_info=False,
  112. )
  113. return None
  114. @retry(
  115. reraise=True,
  116. stop=stop_after_attempt(self.config.num_retries),
  117. wait=wait_random_exponential(
  118. multiplier=self.config.retry_multiplier,
  119. min=self.config.retry_min_wait,
  120. max=self.config.retry_max_wait,
  121. ),
  122. retry=retry_if_exception_type(
  123. (
  124. APIConnectionError,
  125. ContentPolicyViolationError,
  126. InternalServerError,
  127. OpenAIError,
  128. RateLimitError,
  129. )
  130. ),
  131. after=attempt_on_error,
  132. )
  133. def wrapper(*args, **kwargs):
  134. """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
  135. # some callers might just send the messages directly
  136. if 'messages' in kwargs:
  137. messages = kwargs['messages']
  138. else:
  139. messages = args[1]
  140. # log the prompt
  141. debug_message = ''
  142. for message in messages:
  143. content = message['content']
  144. if isinstance(content, list):
  145. for element in content:
  146. if isinstance(element, dict):
  147. if 'text' in element:
  148. content_str = element['text'].strip()
  149. elif (
  150. 'image_url' in element and 'url' in element['image_url']
  151. ):
  152. content_str = element['image_url']['url']
  153. else:
  154. content_str = str(element)
  155. else:
  156. content_str = str(element)
  157. debug_message += message_separator + content_str
  158. else:
  159. content_str = str(content)
  160. debug_message += message_separator + content_str
  161. llm_prompt_logger.debug(debug_message)
  162. # skip if messages is empty (thus debug_message is empty)
  163. if debug_message:
  164. resp = completion_unwrapped(*args, **kwargs)
  165. else:
  166. resp = {'choices': [{'message': {'content': ''}}]}
  167. # log the response
  168. message_back = resp['choices'][0]['message']['content']
  169. llm_response_logger.debug(message_back)
  170. # post-process to log costs
  171. self._post_completion(resp)
  172. return resp
  173. self._completion = wrapper # type: ignore
  174. # Async version
  175. self._async_completion = partial(
  176. self._call_acompletion,
  177. model=self.config.model,
  178. api_key=self.config.api_key,
  179. base_url=self.config.base_url,
  180. api_version=self.config.api_version,
  181. custom_llm_provider=self.config.custom_llm_provider,
  182. max_tokens=self.config.max_output_tokens,
  183. timeout=self.config.timeout,
  184. temperature=self.config.temperature,
  185. top_p=self.config.top_p,
  186. drop_params=True,
  187. )
  188. async_completion_unwrapped = self._async_completion
  189. @retry(
  190. reraise=True,
  191. stop=stop_after_attempt(self.config.num_retries),
  192. wait=wait_random_exponential(
  193. multiplier=self.config.retry_multiplier,
  194. min=self.config.retry_min_wait,
  195. max=self.config.retry_max_wait,
  196. ),
  197. retry=retry_if_exception_type(
  198. (
  199. APIConnectionError,
  200. ContentPolicyViolationError,
  201. InternalServerError,
  202. OpenAIError,
  203. RateLimitError,
  204. )
  205. ),
  206. after=attempt_on_error,
  207. )
  208. async def async_completion_wrapper(*args, **kwargs):
  209. """Async wrapper for the litellm acompletion function."""
  210. # some callers might just send the messages directly
  211. if 'messages' in kwargs:
  212. messages = kwargs['messages']
  213. else:
  214. messages = args[1]
  215. # log the prompt
  216. debug_message = ''
  217. for message in messages:
  218. content = message['content']
  219. if isinstance(content, list):
  220. for element in content:
  221. if isinstance(element, dict):
  222. if 'text' in element:
  223. content_str = element['text']
  224. elif (
  225. 'image_url' in element and 'url' in element['image_url']
  226. ):
  227. content_str = element['image_url']['url']
  228. else:
  229. content_str = str(element)
  230. else:
  231. content_str = str(element)
  232. debug_message += message_separator + content_str
  233. else:
  234. content_str = str(content)
  235. debug_message += message_separator + content_str
  236. llm_prompt_logger.debug(debug_message)
  237. async def check_stopped():
  238. while True:
  239. if (
  240. hasattr(self.config, 'on_cancel_requested_fn')
  241. and self.config.on_cancel_requested_fn is not None
  242. and await self.config.on_cancel_requested_fn()
  243. ):
  244. raise UserCancelledError('LLM request cancelled by user')
  245. await asyncio.sleep(0.1)
  246. stop_check_task = asyncio.create_task(check_stopped())
  247. try:
  248. # Directly call and await litellm_acompletion
  249. resp = await async_completion_unwrapped(*args, **kwargs)
  250. # skip if messages is empty (thus debug_message is empty)
  251. if debug_message:
  252. message_back = resp['choices'][0]['message']['content']
  253. llm_response_logger.debug(message_back)
  254. else:
  255. resp = {'choices': [{'message': {'content': ''}}]}
  256. self._post_completion(resp)
  257. # We do not support streaming in this method, thus return resp
  258. return resp
  259. except UserCancelledError:
  260. logger.info('LLM request cancelled by user.')
  261. raise
  262. except (
  263. APIConnectionError,
  264. ContentPolicyViolationError,
  265. InternalServerError,
  266. NotFoundError,
  267. OpenAIError,
  268. RateLimitError,
  269. ServiceUnavailableError,
  270. ) as e:
  271. logger.error(f'Completion Error occurred:\n{e}')
  272. raise
  273. finally:
  274. await asyncio.sleep(0.1)
  275. stop_check_task.cancel()
  276. try:
  277. await stop_check_task
  278. except asyncio.CancelledError:
  279. pass
  280. @retry(
  281. reraise=True,
  282. stop=stop_after_attempt(self.config.num_retries),
  283. wait=wait_random_exponential(
  284. multiplier=self.config.retry_multiplier,
  285. min=self.config.retry_min_wait,
  286. max=self.config.retry_max_wait,
  287. ),
  288. retry=retry_if_exception_type(
  289. (
  290. APIConnectionError,
  291. ContentPolicyViolationError,
  292. InternalServerError,
  293. OpenAIError,
  294. RateLimitError,
  295. )
  296. ),
  297. after=attempt_on_error,
  298. )
  299. async def async_acompletion_stream_wrapper(*args, **kwargs):
  300. """Async wrapper for the litellm acompletion with streaming function."""
  301. # some callers might just send the messages directly
  302. if 'messages' in kwargs:
  303. messages = kwargs['messages']
  304. else:
  305. messages = args[1]
  306. # log the prompt
  307. debug_message = ''
  308. for message in messages:
  309. debug_message += message_separator + message['content']
  310. llm_prompt_logger.debug(debug_message)
  311. try:
  312. # Directly call and await litellm_acompletion
  313. resp = await async_completion_unwrapped(*args, **kwargs)
  314. # For streaming we iterate over the chunks
  315. async for chunk in resp:
  316. # Check for cancellation before yielding the chunk
  317. if (
  318. hasattr(self.config, 'on_cancel_requested_fn')
  319. and self.config.on_cancel_requested_fn is not None
  320. and await self.config.on_cancel_requested_fn()
  321. ):
  322. raise UserCancelledError(
  323. 'LLM request cancelled due to CANCELLED state'
  324. )
  325. # with streaming, it is "delta", not "message"!
  326. message_back = chunk['choices'][0]['delta']['content']
  327. llm_response_logger.debug(message_back)
  328. self._post_completion(chunk)
  329. yield chunk
  330. except UserCancelledError:
  331. logger.info('LLM request cancelled by user.')
  332. raise
  333. except (
  334. APIConnectionError,
  335. ContentPolicyViolationError,
  336. InternalServerError,
  337. NotFoundError,
  338. OpenAIError,
  339. RateLimitError,
  340. ServiceUnavailableError,
  341. ) as e:
  342. logger.error(f'Completion Error occurred:\n{e}')
  343. raise
  344. finally:
  345. if kwargs.get('stream', False):
  346. await asyncio.sleep(0.1)
  347. self._async_completion = async_completion_wrapper # type: ignore
  348. self._async_streaming_completion = async_acompletion_stream_wrapper # type: ignore
  349. async def _call_acompletion(self, *args, **kwargs):
  350. return await litellm.acompletion(*args, **kwargs)
  351. @property
  352. def completion(self):
  353. """Decorator for the litellm completion function.
  354. Check the complete documentation at https://litellm.vercel.app/docs/completion
  355. """
  356. try:
  357. return self._completion
  358. except Exception as e:
  359. raise LLMResponseError(e)
  360. @property
  361. def async_completion(self):
  362. """Decorator for the async litellm acompletion function.
  363. Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
  364. """
  365. try:
  366. return self._async_completion
  367. except Exception as e:
  368. raise LLMResponseError(e)
  369. @property
  370. def async_streaming_completion(self):
  371. """Decorator for the async litellm acompletion function with streaming.
  372. Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
  373. """
  374. try:
  375. return self._async_streaming_completion
  376. except Exception as e:
  377. raise LLMResponseError(e)
  378. def supports_vision(self):
  379. return litellm.supports_vision(self.config.model)
  380. def _post_completion(self, response) -> None:
  381. """Post-process the completion response."""
  382. try:
  383. cur_cost = self.completion_cost(response)
  384. except Exception:
  385. cur_cost = 0
  386. stats = ''
  387. if self.cost_metric_supported:
  388. stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
  389. cur_cost,
  390. self.metrics.accumulated_cost,
  391. )
  392. usage = response.get('usage')
  393. if usage:
  394. input_tokens = usage.get('prompt_tokens')
  395. output_tokens = usage.get('completion_tokens')
  396. if input_tokens:
  397. stats += 'Input tokens: ' + str(input_tokens) + '\n'
  398. if output_tokens:
  399. stats += 'Output tokens: ' + str(output_tokens) + '\n'
  400. model_extra = usage.get('model_extra', {})
  401. cache_creation_input_tokens = model_extra.get('cache_creation_input_tokens')
  402. if cache_creation_input_tokens:
  403. stats += (
  404. 'Input tokens (cache write): '
  405. + str(cache_creation_input_tokens)
  406. + '\n'
  407. )
  408. cache_read_input_tokens = model_extra.get('cache_read_input_tokens')
  409. if cache_read_input_tokens:
  410. stats += (
  411. 'Input tokens (cache read): ' + str(cache_read_input_tokens) + '\n'
  412. )
  413. if stats:
  414. logger.info(stats)
  415. def get_token_count(self, messages):
  416. """Get the number of tokens in a list of messages.
  417. Args:
  418. messages (list): A list of messages.
  419. Returns:
  420. int: The number of tokens.
  421. """
  422. return litellm.token_counter(model=self.config.model, messages=messages)
  423. def is_local(self):
  424. """Determines if the system is using a locally running LLM.
  425. Returns:
  426. boolean: True if executing a local model.
  427. """
  428. if self.config.base_url is not None:
  429. for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
  430. if substring in self.config.base_url:
  431. return True
  432. elif self.config.model is not None:
  433. if self.config.model.startswith('ollama'):
  434. return True
  435. return False
  436. def completion_cost(self, response):
  437. """Calculate the cost of a completion response based on the model. Local models are treated as free.
  438. Add the current cost into total cost in metrics.
  439. Args:
  440. response: A response from a model invocation.
  441. Returns:
  442. number: The cost of the response.
  443. """
  444. if not self.cost_metric_supported:
  445. return 0.0
  446. extra_kwargs = {}
  447. if (
  448. self.config.input_cost_per_token is not None
  449. and self.config.output_cost_per_token is not None
  450. ):
  451. cost_per_token = CostPerToken(
  452. input_cost_per_token=self.config.input_cost_per_token,
  453. output_cost_per_token=self.config.output_cost_per_token,
  454. )
  455. logger.info(f'Using custom cost per token: {cost_per_token}')
  456. extra_kwargs['custom_cost_per_token'] = cost_per_token
  457. if not self.is_local():
  458. try:
  459. cost = litellm_completion_cost(
  460. completion_response=response, **extra_kwargs
  461. )
  462. self.metrics.add_cost(cost)
  463. return cost
  464. except Exception:
  465. self.cost_metric_supported = False
  466. logger.warning('Cost calculation not supported for this model.')
  467. return 0.0
  468. def __str__(self):
  469. if self.config.api_version:
  470. return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
  471. elif self.config.base_url:
  472. return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
  473. return f'LLM(model={self.config.model})'
  474. def __repr__(self):
  475. return str(self)
  476. def reset(self):
  477. self.metrics = Metrics()