import asyncio
import copy
import warnings
from functools import partial

from openhands.core.config import LLMConfig

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import litellm
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
    APIConnectionError,
    ContentPolicyViolationError,
    InternalServerError,
    OpenAIError,
    RateLimitError,
    ServiceUnavailableError,
)
from litellm.types.utils import CostPerToken
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import llm_prompt_logger, llm_response_logger
from openhands.core.logger import openhands_logger as logger
from openhands.core.metrics import Metrics

__all__ = ['LLM']

message_separator = '\n\n----------\n\n'


class LLM:
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration
        """
        self.config = copy.deepcopy(config)
        self.metrics = metrics if metrics is not None else Metrics()
        self.cost_metric_supported = True

        # Set up config attributes with default values to prevent AttributeError
        LLMConfig.set_missing_attributes(self.config)

        # litellm actually uses base Exception here for unknown model
        self.model_info = None
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
            else:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
        # noinspection PyBroadException
        except Exception as e:
            logger.warning(f'Could not get model info for {config.model}:\n{e}')

        # Set the max tokens in an LM-specific way if not set
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                self.config.max_input_tokens = 4096

        if self.config.max_output_tokens is None:
            if (
                self.model_info is not None
                and 'max_output_tokens' in self.model_info
                and isinstance(self.model_info['max_output_tokens'], int)
            ):
                self.config.max_output_tokens = self.model_info['max_output_tokens']
            else:
                # Max output tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                self.config.max_output_tokens = 1024

        if self.config.drop_params:
            litellm.drop_params = self.config.drop_params

        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        completion_unwrapped = self._completion

        def attempt_on_error(retry_state):
            logger.error(
                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize these settings in the configuration.',
                exc_info=False,
            )
            return None

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    RateLimitError,
                    APIConnectionError,
                    ServiceUnavailableError,
                    InternalServerError,
                    ContentPolicyViolationError,
                )
            ),
            after=attempt_on_error,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                content = message['content']
                if isinstance(content, list):
                    for element in content:
                        if isinstance(element, dict):
                            if 'text' in element:
                                content_str = element['text'].strip()
                            elif (
                                'image_url' in element and 'url' in element['image_url']
                            ):
                                content_str = element['image_url']['url']
                            else:
                                content_str = str(element)
                        else:
                            content_str = str(element)

                        debug_message += message_separator + content_str
                else:
                    content_str = str(content)
                    debug_message += message_separator + content_str

            llm_prompt_logger.debug(debug_message)

            # skip if messages is empty (thus debug_message is empty)
            if debug_message:
                resp = completion_unwrapped(*args, **kwargs)
            else:
                resp = {'choices': [{'message': {'content': ''}}]}

            # log the response
            message_back = resp['choices'][0]['message']['content']
            llm_response_logger.debug(message_back)

            # post-process to log costs
            self._post_completion(resp)
            return resp

        self._completion = wrapper  # type: ignore

        # Async version
        self._async_completion = partial(
            self._call_acompletion,
            model=self.config.model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=True,
        )

        async_completion_unwrapped = self._async_completion

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    RateLimitError,
                    APIConnectionError,
                    ServiceUnavailableError,
                    InternalServerError,
                    ContentPolicyViolationError,
                )
            ),
            after=attempt_on_error,
        )
        async def async_completion_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                content = message['content']
                if isinstance(content, list):
                    for element in content:
                        if isinstance(element, dict):
                            if 'text' in element:
                                content_str = element['text']
                            elif (
                                'image_url' in element and 'url' in element['image_url']
                            ):
                                content_str = element['image_url']['url']
                            else:
                                content_str = str(element)
                        else:
                            content_str = str(element)

                        debug_message += message_separator + content_str
                else:
                    content_str = str(content)
                    debug_message += message_separator + content_str

            llm_prompt_logger.debug(debug_message)

            async def check_stopped():
                while True:
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError('LLM request cancelled by user')
                    await asyncio.sleep(0.1)

            stop_check_task = asyncio.create_task(check_stopped())

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # skip if messages is empty (thus debug_message is empty)
                if debug_message:
                    message_back = resp['choices'][0]['message']['content']
                    llm_response_logger.debug(message_back)
                else:
                    resp = {'choices': [{'message': {'content': ''}}]}

                self._post_completion(resp)

                # We do not support streaming in this method, thus return resp
                return resp
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except OpenAIError as e:
                logger.error(f'OpenAIError occurred:\n{e}')
                raise
            except (
                RateLimitError,
                APIConnectionError,
                ServiceUnavailableError,
                InternalServerError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                await asyncio.sleep(0.1)
                stop_check_task.cancel()
                try:
                    await stop_check_task
                except asyncio.CancelledError:
                    pass

        @retry(
            reraise=True,
            stop=stop_after_attempt(self.config.num_retries),
            wait=wait_random_exponential(
                multiplier=self.config.retry_multiplier,
                min=self.config.retry_min_wait,
                max=self.config.retry_max_wait,
            ),
            retry=retry_if_exception_type(
                (
                    RateLimitError,
                    APIConnectionError,
                    ServiceUnavailableError,
                    InternalServerError,
                    ContentPolicyViolationError,
                )
            ),
            after=attempt_on_error,
        )
        async def async_acompletion_stream_wrapper(*args, **kwargs):
            """Async wrapper for the litellm acompletion with streaming function."""
            # some callers might just send the messages directly
            if 'messages' in kwargs:
                messages = kwargs['messages']
            else:
                messages = args[1]

            # log the prompt
            debug_message = ''
            for message in messages:
                debug_message += message_separator + message['content']
            llm_prompt_logger.debug(debug_message)

            try:
                # Directly call and await litellm_acompletion
                resp = await async_completion_unwrapped(*args, **kwargs)

                # For streaming we iterate over the chunks
                async for chunk in resp:
                    # Check for cancellation before yielding the chunk
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
                        and self.config.on_cancel_requested_fn is not None
                        and await self.config.on_cancel_requested_fn()
                    ):
                        raise UserCancelledError(
                            'LLM request cancelled due to CANCELLED state'
                        )
                    # with streaming, it is "delta", not "message"!
                    message_back = chunk['choices'][0]['delta']['content']
                    llm_response_logger.debug(message_back)

                    self._post_completion(chunk)
                    yield chunk
            except UserCancelledError:
                logger.info('LLM request cancelled by user.')
                raise
            except OpenAIError as e:
                logger.error(f'OpenAIError occurred:\n{e}')
                raise
            except (
                RateLimitError,
                APIConnectionError,
                ServiceUnavailableError,
                InternalServerError,
            ) as e:
                logger.error(f'Completion Error occurred:\n{e}')
                raise
            finally:
                if kwargs.get('stream', False):
                    await asyncio.sleep(0.1)

        self._async_completion = async_completion_wrapper  # type: ignore
        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore

    async def _call_acompletion(self, *args, **kwargs):
        return await litellm.acompletion(*args, **kwargs)

    @property
    def completion(self):
        """Wrapper for the litellm completion function.

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion
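
    # Illustrative usage (not from the original source): a caller typically invokes the
    # wrapped completion with an OpenAI-style messages list, for example:
    #
    #     resp = llm.completion(messages=[{'role': 'user', 'content': 'Hello'}])
    #     text = resp['choices'][0]['message']['content']
    #
    # The retry/backoff policy and prompt/response logging configured in __init__
    # apply to this call automatically.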

    @property
    def async_completion(self):
        """Async wrapper for the litellm acompletion function.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        return self._async_completion

    @property
    def async_streaming_completion(self):
        """Async wrapper for the litellm acompletion function, with streaming.

        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
        """
        return self._async_streaming_completion
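
    # Illustrative sketch (assumed usage, not from the original source): the streaming
    # wrapper is an async generator, so a caller passes stream=True and iterates it:
    #
    #     async for chunk in llm.async_streaming_completion(
    #         messages=[{'role': 'user', 'content': 'Hello'}], stream=True
    #     ):
    #         print(chunk['choices'][0]['delta']['content'], end='')
    #
    # Note that streamed chunks carry 'delta' rather than 'message', and the wrapper
    # checks config.on_cancel_requested_fn between chunks to support cancellation.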

    def supports_vision(self):
        """Returns whether the configured model supports image inputs, as reported by litellm."""
        return litellm.supports_vision(self.config.model)

    def _post_completion(self, response) -> None:
        """Post-process the completion response."""
        try:
            cur_cost = self.completion_cost(response)
        except Exception:
            cur_cost = 0
        if self.cost_metric_supported:
            logger.info(
                'Cost: %.2f USD | Accumulated Cost: %.2f USD',
                cur_cost,
                self.metrics.accumulated_cost,
            )

    def get_token_count(self, messages):
        """Get the number of tokens in a list of messages.

        Args:
            messages (list): A list of messages.

        Returns:
            int: The number of tokens.
        """
        return litellm.token_counter(model=self.config.model, messages=messages)

    def is_local(self):
        """Determines if the system is using a locally running LLM.

        Returns:
            boolean: True if executing a local model.
        """
        if self.config.base_url is not None:
            for substring in ['localhost', '127.0.0.1', '0.0.0.0']:
                if substring in self.config.base_url:
                    return True
        elif self.config.model is not None:
            if self.config.model.startswith('ollama'):
                return True
        return False

    def completion_cost(self, response):
        """Calculate the cost of a completion response based on the model. Local models are treated as free.

        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.info(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        if not self.is_local():
            try:
                cost = litellm_completion_cost(
                    completion_response=response, **extra_kwargs
                )
                self.metrics.add_cost(cost)
                return cost
            except Exception:
                self.cost_metric_supported = False
                logger.warning('Cost calculation not supported for this model.')
        return 0.0

    def __str__(self):
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'

    def __repr__(self):
        return str(self)

    def reset(self):
        self.metrics = Metrics()
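

# Minimal usage sketch (not part of the original module): construct an LLM from an
# LLMConfig and issue a single completion. The LLMConfig keyword arguments and the
# placeholder model/key below are illustrative assumptions, not values from this repo.
if __name__ == '__main__':
    sample_config = LLMConfig(model='gpt-4o', api_key='sk-...')  # hypothetical values
    llm = LLM(sample_config)

    response = llm.completion(
        messages=[{'role': 'user', 'content': 'Say hello in one word.'}]
    )
    # The completion wrapper logs the prompt/response and accumulates cost in llm.metrics.
    print(response['choices'][0]['message']['content'])
    print(f'Accumulated cost: {llm.metrics.accumulated_cost:.4f} USD')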