lm_utils.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import os
  2. import sys
  3. import time
  4. from openai import OpenAI
  5. from tenacity import (
  6. retry,
  7. stop_after_attempt, # type: ignore
  8. wait_random_exponential, # type: ignore
  9. )
  10. if sys.version_info >= (3, 8):
  11. from typing import Literal
  12. else:
  13. from typing_extensions import Literal
  14. Model = Literal['gpt-4', 'gpt-3.5-turbo', 'text-davinci-003']
  15. OpenAI.api_key = os.getenv('OPENAI_API_KEY')
  16. OPENAI_GEN_HYP = {
  17. 'temperature': 0,
  18. 'max_tokens': 250,
  19. 'top_p': 1.0,
  20. 'frequency_penalty': 0,
  21. 'presence_penalty': 0,
  22. }
  23. @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
  24. def run_chatgpt_query_multi_turn(
  25. messages,
  26. model_name='gpt-4-turbo', # pass "gpt4" for more recent model output
  27. max_tokens=256,
  28. temperature=0.0,
  29. json_response=False,
  30. ):
  31. response = None
  32. num_retries = 3
  33. retry = 0
  34. while retry < num_retries:
  35. retry += 1
  36. try:
  37. client = OpenAI()
  38. if json_response:
  39. response = client.chat.completions.create(
  40. model=model_name,
  41. response_format={'type': 'json_object'},
  42. messages=messages,
  43. **OPENAI_GEN_HYP,
  44. )
  45. else:
  46. response = client.chat.completions.create(
  47. model=model_name, messages=messages, **OPENAI_GEN_HYP
  48. )
  49. break
  50. except Exception as e:
  51. print(e)
  52. print('GPT error. Retrying in 2 seconds...')
  53. time.sleep(2)
  54. return response