reasoning.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. import ast
  2. import logging
  3. import re
  4. import traceback
  5. from typing import Any
  6. import numpy as np
  7. from sympy import Rational
  8. from tasks.base import Task
  # Logger registered under the 'MINT' name; logging.getLogger returns the
  # same object for the same name, so other MINT modules can share it.
  LOGGER = logging.getLogger(name='MINT')
  class ReasoningTask(Task):
      """Free-form reasoning task scored by numeric tolerance or substring match.

      The reference is stringified, stripped and lower-cased at construction;
      a solution is lower-cased and stripped before it is scored.
      """

      task_name = 'reasoning'

      def __init__(self, id: str, prompt: str, reference: str, **kwargs):
          super().__init__(**kwargs)
          self._id = id
          self._prompt = prompt.strip()
          self._reference = str(reference).strip().lower()

      def extract_answer(self, solution: str) -> str | None:
          """Extract the answer from the given solution.

          The whole solution is treated as the answer; it is only lower-cased
          and stripped.
          """
          return solution.lower().strip()

      def compare_w_digits(self, reference: str, answer: str) -> bool:
          """Compare the reference and answer, numerically when possible.

          If both strings parse with ``float()``, the answer must lie within 5%
          of the reference (relative to ``|reference|``; a zero reference
          therefore needs an exact match). Otherwise this falls back to
          checking that ``reference`` is a substring of ``answer``.

          Raises:
              ValueError: if the values cannot be compared at all (for example
                  a non-string argument); the underlying error is chained.
          """
          try:
              ref_value = float(reference)
              ans_value = float(answer)
          except ValueError:
              # At least one side is not numeric: lenient substring match.
              return reference in answer
          except Exception as err:
              raise ValueError(f'Cannot compare {reference} and {answer}') from err
          return abs(ref_value - ans_value) <= 0.05 * abs(ref_value)

      def success(self, solution: str) -> bool:
          """Return True if ``solution`` matches the stored reference."""
          answer = self.extract_answer(solution)
          return self.compare_w_digits(self._reference, answer)
  class MultipleChoiceTask(Task):
      """Subclass of Task for multiple choice tasks.

      ``extract_options`` parses the text after ``'Options: '`` as
      ``' , '``-separated ``letter) content`` entries. The reference is used
      as a key into those options, i.e. it is an option letter.
      """

      task_name = 'reasoning'

      def __init__(self, id, prompt: str, reference: str, **kwargs):
          super().__init__(**kwargs)
          self._id = id
          self.hide_options = kwargs.get('hide_options', False)
          if self.hide_options:
              # Keep only the question text that precedes the options list.
              self._prompt = prompt.split('Options:')[0].strip()
          else:
              self._prompt = prompt
          self._reference = reference.strip().lower()
          self._options = self.extract_options(prompt)
          # If every option parses as a number, force hide_options on.
          # NOTE(review): this runs after self._prompt was built, so it does not
          # strip the options from the prompt here; confirm whether Task reads
          # hide_options elsewhere.
          if all(self._is_float(option) for option in self._options.values()):
              self.hide_options = True
          self.metadata.update({'options': self._options})

      @staticmethod
      def _is_float(text: str) -> bool:
          """Return True if ``float(text)`` succeeds."""
          try:
              float(text)
          except ValueError:
              return False
          return True

      def extract_answer(self, solution: str) -> str | None:
          """Extract the selected option letter from ``solution``.

          Returns the first letter ``x`` (scanning a..z) for which ``'x)'`` or
          ``'x )'`` occurs in the lower-cased solution. Otherwise returns the
          whole lower-cased, stripped solution.
          """
          solution = solution.lower().strip()
          for letter in 'abcdefghijklmnopqrstuvwxyz':
              if f'{letter})' in solution or f'{letter} )' in solution:
                  LOGGER.debug('SOLUTION %s', letter)
                  return letter
          LOGGER.debug('SOLUTION %s', solution)
          return solution

      def compare_w_digits(self, reference: str, answer: str) -> bool:
          """Compare two strings, numerically when both are decimal digit runs.

          Two decimal-digit strings match when ``answer`` is within 5% of
          ``reference``. Anything else falls back to a substring check.
          """
          # isdecimal() rather than isdigit(): isdigit() also accepts characters
          # such as '²' that float() cannot parse.
          if reference.isdecimal() and answer.isdecimal():
              return abs(float(reference) - float(answer)) <= 0.05 * float(reference)
          return reference in answer

      def success(self, solution: str) -> bool:
          """Return True if ``solution`` selects the reference option."""
          answer = self.extract_answer(solution)
          # Match the reference letter as a standalone token. A plain substring
          # test accepts almost any free-text answer ('a' occurs in
          # "the answer is b").
          if re.search(rf'\b{re.escape(self._reference)}\b', answer):
              return True
          correct_option = self._options.get(self._reference)
          if correct_option is None:
              # The reference is not a parsed option letter: nothing to compare.
              return False
          LOGGER.debug('OPTIONS %s %s', correct_option, list(self._options.values()))
          LOGGER.debug('ANSWER %s', answer)
          # Options contained in the correct option's text (including the
          # correct option itself) are not treated as distractors. Build a new
          # list instead of removing while iterating, which skips elements.
          wrong_options = [
              option for option in self._options.values()
              if option not in correct_option
          ]
          for option in wrong_options:
              if self.compare_w_digits(option, answer) or option in answer:
                  return False
          return self.compare_w_digits(correct_option, answer) or correct_option in answer

      def extract_options(self, prompt: str) -> dict:
          """Parse ``letter) content`` options from the text after 'Options: '.

          Entries are separated by ``' , '``. Keys are lower-cased letters.
          Values are the lower-cased text after the first ``')'``, with the
          trailing "which option/one is correct?" sentence and surrounding
          dots/whitespace removed. Entries without a ``')'`` are skipped.
          """
          options_text = prompt.split('Options: ')[-1]
          options = {}
          for raw_option in options_text.split(' , '):
              # partition keeps any later ')' inside the option text itself.
              letter, sep, content = raw_option.strip("[]' ").partition(')')
              if not sep:
                  continue
              content = (
                  content.lower()
                  # Needles are lower-case because the text already is.
                  .replace('. which option is correct?', '')
                  .replace('. which one is correct?', '')
                  .strip()
                  .strip('.')
                  .strip()
              )
              options[letter.lower().strip()] = content
          return options
  # ==== TheoremQA ====
  def compare_two_numbers(p, gt):
      """Return whether prediction ``p`` matches ground truth ``gt``.

      A float ``gt`` is compared with ``within_eps``; any other ``gt`` is
      compared for equality against ``round(p)``.

      Raises:
          ValueError: if ``p`` has an unsupported type (e.g. ``None``).
      """
      # bool is a subclass of int, so it must be rejected before the numeric
      # check; otherwise True/False would be scored as 1/0.
      if isinstance(p, (bool, complex, dict, list, str, tuple)):
          return False
      if not isinstance(p, (int, float)):
          raise ValueError(p)
      if isinstance(gt, float):
          return within_eps(pred=p, gt=gt)
      try:
          return round(p) == gt
      except (OverflowError, ValueError):
          # round() raises on inf/nan; a non-finite prediction never matches.
          return False
  def compare_two_list(pred, gt):
      """Return whether number list ``pred`` matches ``gt`` after sorting.

      Both sides are sorted and compared pairwise with ``compare_two_numbers``.
      Returns False if ``pred`` is not a list of int/float values, if ``gt``
      is not a list/tuple, or if the lengths differ.
      """
      # gt comes from parsing the reference and may be a scalar or None.
      if not isinstance(pred, list) or not isinstance(gt, (list, tuple)):
          return False
      if len(pred) != len(gt):
          return False
      if not all(isinstance(x, (int, float)) for x in pred):
          return False
      return all(compare_two_numbers(p, g) for p, g in zip(sorted(pred), sorted(gt)))
  def within_eps(pred: float, gt: float, rel_tol: float = 0.04) -> bool:
      """Return True if ``pred`` lies within ``rel_tol * |gt|`` of ``gt``.

      Bounds are inclusive. With ``gt == 0`` only an exact match passes, and
      NaN on either side never matches. ``rel_tol`` defaults to 4%.
      """
      eps = abs(gt) * rel_tol
      return gt - eps <= pred <= gt + eps
  def parse_number_list(s: str):
      """Parse ``s`` as a Python literal (e.g. ``'[1, 2.5]'``) and return it.

      Uses ``ast.literal_eval``, so any literal is accepted, not only lists.
      Non-literal or malformed input raises (``ValueError``/``SyntaxError``).
      """
      return ast.literal_eval(s)
  def is_number(string):
      """Return True if ``string`` looks like a plain decimal number.

      Accepts an optional sign, either a run of digits or comma-grouped
      thousands (``1,234``), and an optional fractional part. Exponent
      notation is not accepted.
      """
      number_re = r'^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$'
      return re.match(number_re, string) is not None
  def is_scientific_number(string):
      """Return True for strings like ``'1.5e-3'`` or ``'2e10'``.

      Only a lower-case ``e`` is recognized, and the exponent may carry a
      ``-`` sign but not a ``+`` sign.
      """
      sci_re = r'^[-+]?\d+(\.\d+)?e[-]?\d+$'
      return re.match(sci_re, string) is not None
  def contain_num_and_str(string):
      """Return True if ``string`` has at least one ASCII letter and one digit."""
      has_letter = re.search(r'[a-zA-Z]', string) is not None
      has_digit = re.search(r'[0-9]', string) is not None
      return has_letter and has_digit
  class TheoremqaTask(Task):
      """TheoremQA task whose answers are numbers, number lists, bools or options.

      Answer extraction and scoring follow the TheoremQA evaluation script
      linked in the method comments.
      """

      task_name = 'reasoning'

      def __init__(self, id: str, prompt: str, reference: str, **kwargs):
          super().__init__(**kwargs)
          self._id = id
          self._prompt = (
              'Answer the following question with a number, a list of numbers or True or False. '
              + prompt.strip()
          )
          self._reference = reference
          # Values recognized by success(): 'bool', 'option'/'Option',
          # 'integer', 'float', 'list of integer', 'list of float'.
          self._answer_type = kwargs.get('answer_type')

      @staticmethod
      def _preprocess_prediction(prediction: str) -> str:
          """Normalize a raw answer string into text that ``eval`` can parse.

          Keeps the right-hand side of '='/'≈', strips markup and units, maps
          boolean keywords to 'True'/'False', rewrites '10^k', ' x ' and ' × '
          notation, and turns option letters into a quoted '"(a)"'..'"(d)"'.
          Returns '0' when nothing is left.
          """
          # Replace special tokens: keep only the right-hand side of an
          # equation or approximation, and drop markup / degree signs.
          if '=' in prediction:
              prediction = prediction.split('=')[-1].strip()
          if '≈' in prediction:
              prediction = prediction.split('≈')[-1].strip()
          for token in ('`', '$', '°'):
              prediction = prediction.replace(token, '')
          # Detect the boolean keyword in the generation (case-insensitively,
          # so 'Yes' / 'NO' / 'TRUE' are recognized too).
          keyword = prediction.strip().lower()
          if keyword in ('true', 'yes', 'false', 'no'):
              prediction = 'True' if keyword in ('true', 'yes') else 'False'
          if 'True' in prediction or 'False' in prediction:
              prediction = 'True' if 'True' in prediction else 'False'
          # Detect the approximation keyword.
          if 'approximately' in prediction:
              prediction = prediction.replace('approximately', '').strip()
          if ' or ' in prediction:
              prediction = prediction.split(' or ')[0]
          # Drop the units before and after the number; each pattern is applied
          # in turn to the result of the previous one.
          for unit_pattern in (
              r'([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$',
              r'[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$',
              r'([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$',
              r'[^-+\d]{1,2}((?:[\d,]*\.*\d+))$',
          ):
              match = re.match(unit_pattern, prediction)
              if match:
                  prediction = match.group(1)
          # Preprocessing the number.
          prediction = re.sub(r'10\^(-?\d+)', r'math.pow(10, \1)', prediction)
          prediction = prediction.replace(' x ', '*').replace(' × ', '*')
          if is_number(prediction):
              prediction = prediction.replace(',', '')
          # Preprocessing the option.
          for letter in 'abcd':
              if (
                  f'{letter})' in prediction
                  or f'{letter} )' in prediction
                  or prediction.lower().strip() == letter
              ):
                  prediction = f'({letter})'
          option = re.search(r'\([a-d]\)', prediction)
          if option:
              # Quote it so eval() yields the string '(a)' rather than a name.
              prediction = '"' + option.group(0) + '"'
          # If the prediction is empty, use dummy '0'.
          return prediction or '0'

      @staticmethod
      def _to_real(value):
          """Map complex values to their real part and sympy Rationals to float."""
          if isinstance(value, complex):
              return value.real
          if isinstance(value, Rational):
              return float(value)
          return value

      def extract_answer(self, solution: str) -> Any:
          """Extract the answer from the given solution.

          Returns the evaluated answer (number, list, bool, or an option string
          such as '(a)'), or None if the preprocessed text cannot be evaluated.
          """
          # Following the preprocessing steps from TheoremQA
          # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170
          import math  # provides the name used by the 'math.pow(10, k)' rewrite

          prediction = solution
          if not isinstance(prediction, str):
              prediction = str(prediction) if prediction is not None else '0'
          prediction = self._preprocess_prediction(prediction)
          # Converting the string answer to a number/list/bool/option.
          # NOTE(review): eval() executes model-generated text; consider a
          # restricted evaluator if this runs outside a sandbox.
          try:
              # 'math' is supplied explicitly because it is not among this
              # module's imports; without it every '10^k' answer fails.
              prediction = eval(prediction, globals(), {'math': math})
          except Exception:
              LOGGER.warning(
                  '[TASK] Failed to convert the answer: %s\n%s',
                  prediction,
                  traceback.format_exc(),
              )
              return None  # failed to convert the answer
          # Performing common type conversion, element-wise for collections
          # (also safe for empty tuples/sets).
          if isinstance(prediction, (set, tuple, list)):
              return [self._to_real(item) for item in prediction]
          if isinstance(prediction, np.ndarray):
              return prediction.tolist()
          return self._to_real(prediction)

      def success(self, solution: str) -> bool:
          """This checks whether the given solution can complete the current task."""
          # Follow the implementation from TheoremQA
          # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1
          prediction = self.extract_answer(solution)
          LOGGER.info('TheoremQA Parsed Prediction: %s', prediction)
          answer_type = self._answer_type
          gt = self.extract_answer(self.reference)
          # Default to incorrect. This covers unsupported prediction types
          # (e.g. None after a failed parse) and unrecognized answer types,
          # which previously left cur_correct unbound.
          cur_correct = False
          if isinstance(prediction, (str, int, float, list)):
              # Comparing prediction against the reference
              if answer_type in ['bool', 'option', 'Option']:
                  cur_correct = prediction == f'({gt})' or prediction == gt
              elif answer_type in ('integer', 'float'):
                  cur_correct = compare_two_numbers(prediction, gt)
              elif answer_type in ('list of integer', 'list of float'):
                  cur_correct = compare_two_list(prediction, gt)
          return bool(cur_correct)