reasoning.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. import ast
  2. import logging
  3. import re
  4. import traceback
  5. from typing import Any, Optional
  6. import numpy as np
  7. from sympy import Rational
  8. from tasks.base import Task
  9. LOGGER = logging.getLogger('MINT')
  10. class ReasoningTask(Task):
  11. task_name = 'reasoning'
  12. def __init__(self, id: str, prompt: str, reference: str, **kwargs):
  13. super().__init__(**kwargs)
  14. self._id = id
  15. self._prompt = prompt.strip()
  16. self._reference = str(reference).strip().lower()
  17. def extract_answer(self, solution: str) -> Optional[str]:
  18. """Extract the answer from the given solution."""
  19. return solution.lower().strip()
  20. def compare_w_digits(self, reference: str, answer: str) -> bool:
  21. """Compare the reference and answer with digits."""
  22. # if reference can and answer can both be converted to floats by float()
  23. try:
  24. float(reference)
  25. float(answer)
  26. return abs(float(reference) - float(answer)) <= 0.05 * abs(float(reference))
  27. except ValueError:
  28. return reference in answer
  29. except Exception:
  30. raise ValueError(f'Cannot compare {reference} and {answer}')
  31. def success(self, solution: str) -> bool:
  32. answer = self.extract_answer(solution)
  33. return self.compare_w_digits(self._reference, answer)
  34. class MultipleChoiceTask(Task):
  35. """Subclass of Task for multiple choice tasks."""
  36. task_name = 'reasoning'
  37. def __init__(self, id, prompt: str, reference: str, **kwargs):
  38. super().__init__(**kwargs)
  39. self._id = id
  40. self.hide_options = kwargs.get('hide_options', False)
  41. if self.hide_options:
  42. self._prompt = prompt.split('Options:')[0].strip()
  43. else:
  44. self._prompt = prompt
  45. self._reference = reference.strip().lower()
  46. self._options = self.extract_options(prompt)
  47. # if all options can be converted to float, strictly perform hide options
  48. try:
  49. for option in self._options.values():
  50. float(option)
  51. self.hide_options = True
  52. except ValueError:
  53. pass
  54. self.metadata.update({'options': self._options})
  55. def extract_answer(self, solution: str) -> Optional[str]:
  56. # Extract the selected option from the solution
  57. solution = solution.lower().strip()
  58. for letter in 'abcdefghijklmnopqrstuvwxyz':
  59. if f'{letter})' in solution or f'{letter} )' in solution:
  60. print('SOLUTION', letter)
  61. return letter
  62. else:
  63. print('SOLUTION', solution)
  64. return solution
  65. def compare_w_digits(self, reference: str, answer: str) -> bool:
  66. if reference.isdigit() and answer.isdigit():
  67. return abs(float(reference) - float(answer)) <= 0.05 * float(reference)
  68. else:
  69. return reference in answer
  70. def success(self, solution: str) -> bool:
  71. answer = self.extract_answer(solution)
  72. if self.compare_w_digits(self._reference, answer):
  73. return True
  74. else:
  75. correct_option = self._options[self._reference]
  76. wrong_option_list = list(self._options.values())
  77. print('OPTIONS', correct_option, wrong_option_list)
  78. print('ANSWER', answer)
  79. for i in wrong_option_list:
  80. if i in correct_option:
  81. wrong_option_list.remove(i)
  82. for i in wrong_option_list:
  83. if self.compare_w_digits(i, answer) or (i in answer):
  84. return False
  85. if self.compare_w_digits(correct_option, answer) or (
  86. correct_option in answer
  87. ):
  88. return True
  89. else:
  90. return False
  91. def extract_options(self, prompt: str) -> dict:
  92. # Find the possible option separators (comma, semicolon, or parentheses)
  93. prompt = prompt.split('Options: ')[-1]
  94. # Extract the options using the delimiter
  95. options_match = prompt.split(' , ')
  96. options = {}
  97. for i in range(len(options_match)):
  98. option = options_match[i].strip("[]' ")
  99. option = option.split(')')
  100. letter = option[0].lower().strip()
  101. content = (
  102. option[1]
  103. .lower()
  104. .strip('.')
  105. .replace('. Which option is correct?', '')
  106. .replace('. Which one is correct?', '')
  107. .strip()
  108. )
  109. options.update({letter: content})
  110. return options
  111. # ==== TheoremQA ====
  112. def compare_two_numbers(p, gt):
  113. if isinstance(p, int) or isinstance(p, float):
  114. pass
  115. elif isinstance(p, list) or isinstance(p, bool) or isinstance(p, str):
  116. return False
  117. elif isinstance(p, tuple) or isinstance(p, complex) or isinstance(p, dict):
  118. return False
  119. else:
  120. raise ValueError(p)
  121. if isinstance(gt, float):
  122. return within_eps(pred=p, gt=gt)
  123. else:
  124. return round(p) == gt
  125. def compare_two_list(pred, gt):
  126. if not isinstance(pred, list):
  127. return False
  128. elif len(pred) != len(gt):
  129. return False
  130. elif any([not isinstance(x, (int, float)) for x in pred]):
  131. return False
  132. else:
  133. pred = sorted(pred)
  134. gt = sorted(gt)
  135. return all([compare_two_numbers(p, g) for p, g in zip(pred, gt)])
  136. def within_eps(pred: float, gt: float):
  137. eps = abs(gt) * 0.04
  138. if pred >= gt - eps and pred <= gt + eps:
  139. return True
  140. else:
  141. return False
  142. def parse_number_list(s: str):
  143. # Check if the string is a valid list by trying to parse it
  144. parsed_list = ast.literal_eval(s)
  145. return parsed_list
  146. def is_number(string):
  147. pattern = r'^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$'
  148. match = re.match(pattern, string)
  149. return bool(match)
  150. def is_scientific_number(string):
  151. pattern = r'^[-+]?\d+(\.\d+)?e[-]?\d+$'
  152. match = re.match(pattern, string)
  153. return bool(match)
  154. def contain_num_and_str(string):
  155. pattern_str = r'[a-zA-Z]'
  156. pattern_num = r'[0-9]'
  157. return bool(re.search(pattern_str, string) and re.search(pattern_num, string))
  158. class TheoremqaTask(Task):
  159. task_name = 'reasoning'
  160. def __init__(self, id: str, prompt: str, reference: str, **kwargs):
  161. super().__init__(**kwargs)
  162. self._id = id
  163. self._prompt = (
  164. 'Answer the following question with a number, a list of numbers or True or False. '
  165. + prompt.strip()
  166. )
  167. self._reference = reference
  168. self._answer_type = kwargs.get('answer_type')
  169. def extract_answer(self, solution: str) -> Optional[Any]:
  170. """Extract the answer from the given solution."""
  171. prediction = solution
  172. # Following the preprocessing steps from TheoremQA
  173. # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170
  174. # Preprocessing the string [Stage 1]
  175. if not isinstance(prediction, str):
  176. prediction = str(prediction) if prediction is not None else '0'
  177. # Replace special tokens
  178. if '=' in prediction:
  179. prediction = prediction.split('=')[-1].strip()
  180. if '≈' in prediction:
  181. prediction = prediction.split('≈')[-1].strip()
  182. if '`' in prediction:
  183. prediction = prediction.replace('`', '')
  184. if '$' in prediction:
  185. prediction = prediction.replace('$', '')
  186. if '°' in prediction:
  187. prediction = prediction.replace('°', '')
  188. # Detect the boolean keyword in the generation
  189. if prediction in ['true', 'yes', 'false', 'no']:
  190. if prediction == 'true' or prediction == 'yes':
  191. prediction = 'True'
  192. else:
  193. prediction = 'False'
  194. if 'True' in prediction or 'False' in prediction:
  195. prediction = 'True' if 'True' in prediction else 'False'
  196. # Detect the approximation keyword
  197. if 'approximately' in prediction:
  198. prediction = prediction.replace('approximately', '').strip()
  199. if ' or ' in prediction:
  200. prediction = prediction.split(' or ')[0]
  201. # Drop the units before and after the number
  202. if re.match(r'[-+]?(?:[\d,]*\.*\d+) [^0-9 ]+$', prediction):
  203. prediction = re.search(
  204. r'([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$', prediction
  205. ).group(1)
  206. if re.match(r'[^0-9 ]+ [-+]?(?:[\d,]*\.*\d+)$', prediction):
  207. prediction = re.search(
  208. r'[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$', prediction
  209. ).group(1)
  210. if re.match(r'[-+]?(?:[\d,]*\.*\d+)[^\d]{1,2}$', prediction):
  211. prediction = re.search(
  212. r'([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$', prediction
  213. ).group(1)
  214. if re.match(r'[^-+\d]{1,2}(?:[\d,]*\.*\d+)$', prediction):
  215. prediction = re.search(
  216. r'[^-+\d]{1,2}((?:[\d,]*\.*\d+))$', prediction
  217. ).group(1)
  218. # Preprocessing the number [Stage 1]
  219. if '10^' in prediction:
  220. prediction = re.sub(r'10\^(-?\d+)', r'math.pow(10, \1)', prediction)
  221. if ' x ' in prediction:
  222. prediction = prediction.replace(' x ', '*')
  223. if ' × ' in prediction:
  224. prediction = prediction.replace(' × ', '*')
  225. if is_number(prediction):
  226. prediction = prediction.replace(',', '')
  227. # Preprocessing the option [Stage 3]
  228. if (
  229. 'a)' in prediction
  230. or 'a )' in prediction
  231. or prediction.lower().strip() == 'a'
  232. ):
  233. prediction = '(a)'
  234. if (
  235. 'b)' in prediction
  236. or 'b )' in prediction
  237. or prediction.lower().strip() == 'b'
  238. ):
  239. prediction = '(b)'
  240. if (
  241. 'c)' in prediction
  242. or 'c )' in prediction
  243. or prediction.lower().strip() == 'c'
  244. ):
  245. prediction = '(c)'
  246. if (
  247. 'd)' in prediction
  248. or 'd )' in prediction
  249. or prediction.lower().strip() == 'd'
  250. ):
  251. prediction = '(d)'
  252. if (
  253. '(a)' in prediction
  254. or '(b)' in prediction
  255. or '(c)' in prediction
  256. or '(d)' in prediction
  257. ):
  258. prediction = '"' + re.search(r'\([a-d]\)', prediction).group(0) + '"'
  259. # If the prediction is empty, use dummy '0'
  260. if not prediction:
  261. prediction = '0'
  262. # Converting the string answer to a number/list/bool/option
  263. try:
  264. prediction = eval(prediction)
  265. except Exception:
  266. LOGGER.warning(
  267. f'[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}'
  268. )
  269. return None # failed to convert the answer
  270. # Performing common type conversion
  271. if isinstance(prediction, (set, tuple)):
  272. prediction = list(prediction)
  273. if isinstance(prediction[0], complex):
  274. prediction = [tmp.real for tmp in prediction]
  275. elif isinstance(prediction[0], Rational):
  276. prediction = [float(tmp) for tmp in prediction]
  277. elif isinstance(prediction, np.ndarray):
  278. prediction = prediction.tolist()
  279. else:
  280. if isinstance(prediction, complex):
  281. prediction = prediction.real
  282. elif isinstance(prediction, Rational):
  283. prediction = float(prediction)
  284. return prediction
  285. def success(self, solution: str) -> bool:
  286. """This checks whether the given solution can complete the current task."""
  287. # Follow the implementation from TheoremQA
  288. # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1
  289. prediction = self.extract_answer(solution)
  290. LOGGER.info(f'TheoremQA Parsed Prediction: {prediction}')
  291. answer_type = self._answer_type
  292. gt = self.extract_answer(self.reference)
  293. if isinstance(prediction, (str, int, float)) or isinstance(prediction, list):
  294. # Comparing prediction against the reference
  295. if answer_type in ['bool', 'option', 'Option']:
  296. cur_correct = int(prediction == f'({gt})') or int(prediction == gt)
  297. elif answer_type == 'integer':
  298. cur_correct = int(compare_two_numbers(prediction, gt))
  299. elif answer_type == 'float':
  300. cur_correct = int(compare_two_numbers(prediction, gt))
  301. elif answer_type in ['list of integer', 'list of float']:
  302. cur_correct = int(compare_two_list(prediction, gt))
  303. else:
  304. cur_correct = 0
  305. return bool(cur_correct)