# game.py
  1. import logging
  2. import re
  3. from typing import Optional
  4. import openai
  5. import requests.exceptions
  6. from openai import OpenAI
  7. from retry import retry
# Module-level logger; it is passed as the `logger=` argument of the `retry`
# decorators that wrap the `answerer` methods in this file.
LOGGER = logging.getLogger(__name__)
  9. class Q20Game:
  10. def __init__(
  11. self,
  12. item: str,
  13. answerer_model: str = 'gpt-3.5-turbo-0613',
  14. guesser_model: str = 'gpt-3.5-turbo-0613',
  15. num_turns: int = 20,
  16. temperature: float = 0.8,
  17. openai_api: bool = True,
  18. openai_api_key: Optional[str] = None,
  19. guesser_kargs=None,
  20. ) -> None:
  21. if guesser_kargs is None:
  22. guesser_kargs = {}
  23. self.item = item
  24. self.answerer_model = answerer_model
  25. self.guesser_model = guesser_model
  26. self.num_turns = num_turns
  27. self.temperature = temperature
  28. self.openai_api = openai_api
  29. self.guesser_kargs = guesser_kargs
  30. self.vicuna_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
  31. self.first_user_utterance = (
  32. 'Your task is to ask a series of questions to deduce the entity '
  33. "that I'm thinking of with as few queries as possible. "
  34. "Only ask questions that can be answered by 'yes', 'no' or 'maybe'. "
  35. 'Do not ask for hint. Make your question brief with no linebreaker. '
  36. 'Now start asking a question.'
  37. )
  38. self.guesser_win = False
  39. self.curr_turn = 0
  40. if openai_api_key is not None:
  41. openai.api_key = openai_api_key
  42. if isinstance(answerer_model, str) and not answerer_model.startswith('gpt'):
  43. self.user_api_base = 'http://0.0.0.0:8000/v1'
  44. else:
  45. self.user_api_base = 'https://api.openai.com/v1'
  46. if isinstance(guesser_model, str) and not guesser_model.startswith('gpt'):
  47. self.guesser_api_base = 'http://0.0.0.0:8000/v1'
  48. else:
  49. self.guesser_api_base = 'https://api.openai.com/v1'
  50. self.guesser_messages = []
  51. def preprocess_response(self, response):
  52. response = re.sub(r'the entity you are thinking of', 'it', response)
  53. response = re.sub(r"the entity you're thinking of", 'it', response)
  54. response = re.sub(r" you're thinking of", '', response)
  55. response = re.sub(r' you are thinking of', '', response)
  56. return response
  57. def judge_winner(self, response):
  58. guesser_question = response.strip()
  59. if self.curr_turn == self.num_turns - 1:
  60. guesser_question += ' Is it right?'
  61. self.guesser_messages.append({'role': 'assistant', 'content': guesser_question})
  62. # ask for answer
  63. usr_msg = self.answerer(guesser_question)
  64. self.guesser_messages.append(
  65. {'role': 'user', 'content': f"{usr_msg['content'].strip()}"}
  66. )
  67. if 'bingo' in usr_msg['content'].lower():
  68. self.guesser_win = True
  69. return True, ''
  70. return False, usr_msg['content'].strip()
  71. def generate_user_response(self, response):
  72. response = self.preprocess_response(response)
  73. # others
  74. bingo, anwser_reply = self.judge_winner(response)
  75. if bingo:
  76. return (
  77. 'You are bingo! quit now, run: <execute_bash> exit </execute_bash>.\n'
  78. )
  79. if self.curr_turn == self.num_turns - 2:
  80. anwser_reply += " You must guess now, what's it?"
  81. return anwser_reply
  82. def reward(self):
  83. if self.guesser_win:
  84. n_turns = (len(self.guesser_messages) + 1) // 2
  85. return 1 - max(n_turns - 5, 0) * 0.02
  86. return 0
  87. @retry(
  88. (
  89. openai.Timeout,
  90. requests.exceptions.ReadTimeout,
  91. openai.RateLimitError,
  92. openai.APIError,
  93. openai.APIConnectionError,
  94. ),
  95. tries=5,
  96. delay=0.5,
  97. backoff=0.5,
  98. max_delay=2,
  99. logger=LOGGER,
  100. )
  101. def answerer(self, question):
  102. openai.api_base = self.user_api_base
  103. client = OpenAI(api_key=openai.api_key)
  104. user_messages = [
  105. {
  106. 'role': 'user',
  107. 'content': f'Based on your knowledge about {self.item}, '
  108. f'respond to the following question or guess. '
  109. f"Limit your respond to only 'Yes.', 'No.' or 'Maybe.', with no explanation or other words. "
  110. f'Never say the answer {self.item} in your response. '
  111. f"If the question is to solicit the answer, respond 'No.'.",
  112. },
  113. {
  114. 'role': 'user',
  115. 'content': f'For the entity {self.item}, {question} (Yes/No/Maybe)',
  116. },
  117. ]
  118. response = client.chat.completions.create(
  119. model=self.answerer_model,
  120. messages=user_messages,
  121. max_tokens=6,
  122. n=1,
  123. stop=None,
  124. temperature=0.2,
  125. )
  126. if any(
  127. [
  128. re.search(rf'(?:^|\W){i.strip().lower()}(?:$|\W)', question.lower())
  129. for i in self.item.lower().split('|')
  130. ]
  131. ):
  132. response.choices[0].message.content = 'Bingo!'
  133. return response.choices[0].message.to_dict()
  134. class Q20GameCelebrity(Q20Game):
  135. def __init__(self, item: str, **kwargs) -> None:
  136. super().__init__(item, **kwargs)
  137. self.first_user_utterance = (
  138. 'Your task is to ask a series of questions to deduce the celebrity '
  139. "that I'm thinking of with as few queries as possible. "
  140. "Only ask factual questions that can be answered by 'Yes.', 'No.' or 'Dunno.'. Do not ask for hint. Make your question brief with no linebreaker. "
  141. 'Now start asking a question.'
  142. )
  143. @retry(
  144. (
  145. openai.Timeout,
  146. requests.exceptions.ReadTimeout,
  147. openai.RateLimitError,
  148. openai.APIError,
  149. openai.APIConnectionError,
  150. ),
  151. tries=5,
  152. delay=0.5,
  153. backoff=0.5,
  154. max_delay=2,
  155. logger=LOGGER,
  156. )
  157. def answerer(self, question):
  158. openai.api_base = self.user_api_base
  159. client = OpenAI(api_key=openai.api_key)
  160. user_messages = [
  161. {
  162. 'role': 'system',
  163. 'content': f'Based on your knowledge about the celebrity: {self.item}, '
  164. f'respond to the following question or guess. '
  165. f"Limit your respond to only 'Yes.', 'No.' or 'Dunno.', with no explanation or other words. "
  166. f"Never say the name {self.item} in your response. Do not say 'Dunno.' if it can be answered by 'Yes.' or 'No.' "
  167. f"If the question is to solicit the answer, respond 'No.'.",
  168. },
  169. {
  170. 'role': 'user',
  171. 'content': f'For the celebrity {self.item}, {question}(Yes/No/Dunno)',
  172. },
  173. ]
  174. response = client.chat.completions.create(
  175. model=self.answerer_model,
  176. messages=user_messages,
  177. max_tokens=6,
  178. n=1,
  179. stop=None,
  180. temperature=0.2,
  181. )
  182. if re.search(rf'(?:^|\W){self.item.lower()}(?:$|\W)', question.lower()):
  183. response.choices[0].message.content = 'Bingo!'
  184. return response.choices[0].message.to_dict()