| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
import ast
import logging
import math
import re
import traceback
from typing import Any, Optional

import numpy as np
from sympy import Rational

from tasks.base import Task
# Module-wide logger: every task class in this file logs under the 'MINT' logger name.
LOGGER = logging.getLogger('MINT')
class ReasoningTask(Task):
    """Free-form reasoning task graded against a single reference answer.

    The reference is lower-cased and stripped once at construction. Grading
    uses a 5% relative tolerance when both sides parse as floats, and a
    substring check otherwise.
    """

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        super().__init__(**kwargs)
        self._id = id
        self._prompt = prompt.strip()
        # Normalised here so grading is case- and surrounding-whitespace-insensitive.
        self._reference = str(reference).strip().lower()

    def extract_answer(self, solution: str) -> Optional[str]:
        """Extract the answer from the given solution (lower-cased and stripped)."""
        return solution.lower().strip()

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare the reference and answer, numerically when possible.

        Args:
            reference: Expected answer text.
            answer: Extracted answer text.

        Returns:
            If both strings parse with ``float()``, whether ``answer`` lies within
            5% of ``reference`` (relative to ``abs(reference)``, so a reference of 0
            requires an exact match). Otherwise, whether ``reference`` is a
            substring of ``answer``.

        Raises:
            ValueError: if ``float()`` fails for a reason other than malformed text
                (e.g. a ``None`` argument); the original error is chained as cause.
        """
        try:
            ref_value = float(reference)
            ans_value = float(answer)
        except ValueError:
            # At least one side is not numeric: fall back to a substring match.
            return reference in answer
        except Exception as err:
            raise ValueError(f'Cannot compare {reference} and {answer}') from err
        return abs(ref_value - ans_value) <= 0.05 * abs(ref_value)

    def success(self, solution: str) -> bool:
        """Return True if ``solution`` matches the stored reference answer."""
        answer = self.extract_answer(solution)
        return self.compare_w_digits(self._reference, answer)
class MultipleChoiceTask(Task):
    """Subclass of Task for multiple choice tasks.

    Options are parsed from the ``Options: ...`` tail of the prompt into a
    ``{letter: option_text}`` dict (both lower-cased) and recorded in metadata.
    """

    task_name = 'reasoning'

    def __init__(self, id, prompt: str, reference: str, **kwargs):
        super().__init__(**kwargs)
        self._id = id
        self.hide_options = kwargs.get('hide_options', False)
        if self.hide_options:
            # Keep only the question text that precedes the option list.
            self._prompt = prompt.split('Options:')[0].strip()
        else:
            self._prompt = prompt
        # Used as a key into self._options in success().
        self._reference = reference.strip().lower()
        self._options = self.extract_options(prompt)
        # If all options can be converted to float, strictly perform hide options.
        # NOTE(review): self._prompt has already been built above, so setting the
        # flag here does not change the prompt inside this constructor. Confirm
        # that Task or a caller reads self.hide_options later; otherwise this has
        # no effect.
        try:
            for option in self._options.values():
                float(option)
            self.hide_options = True
        except ValueError:
            pass
        self.metadata.update({'options': self._options})

    def extract_answer(self, solution: str) -> Optional[str]:
        """Extract the selected option letter from the solution.

        Returns the first letter a-z that appears as ``'x)'`` or ``'x )'`` in
        the lower-cased solution. Otherwise returns the whole lower-cased,
        stripped solution.
        """
        solution = solution.lower().strip()
        for letter in 'abcdefghijklmnopqrstuvwxyz':
            if f'{letter})' in solution or f'{letter} )' in solution:
                LOGGER.debug('SOLUTION %s', letter)
                return letter
        LOGGER.debug('SOLUTION %s', solution)
        return solution

    def compare_w_digits(self, reference: str, answer: str) -> bool:
        """Compare within 5% when both are plain decimal integers, else by substring.

        ``isdecimal`` (not ``isdigit``) is used because characters such as
        ``'²'`` satisfy ``isdigit`` but make ``float()`` raise.
        """
        if reference.isdecimal() and answer.isdecimal():
            return abs(float(reference) - float(answer)) <= 0.05 * float(reference)
        return reference in answer

    def success(self, solution: str) -> bool:
        """Return True if ``solution`` selects (or states) the correct option.

        The solution is accepted if it matches the reference letter. Otherwise it
        is rejected if it mentions any wrong option's text, and accepted only if
        it mentions the correct option's text.

        Raises:
            KeyError: if the reference letter is not among the parsed options.
        """
        answer = self.extract_answer(solution)
        if self.compare_w_digits(self._reference, answer):
            return True
        correct_option = self._options[self._reference]
        LOGGER.debug('OPTIONS %s %s', correct_option, list(self._options.values()))
        LOGGER.debug('ANSWER %s', answer)
        # Discard every option whose text is contained in the correct option
        # (including the correct option itself) so that it is not treated as a
        # wrong choice. A new list is built because the previous version removed
        # items while iterating, which skipped elements.
        wrong_option_list = [
            option
            for option in self._options.values()
            if option not in correct_option
        ]
        if any(
            self.compare_w_digits(option, answer) or option in answer
            for option in wrong_option_list
        ):
            return False
        return self.compare_w_digits(correct_option, answer) or correct_option in answer

    def extract_options(self, prompt: str) -> dict:
        """Parse the option list after the last ``'Options: '`` marker.

        Entries are separated by ``' , '`` and look like ``"A) text"``. Returns
        ``{letter: text}`` with both parts lower-cased and stripped.

        Raises:
            IndexError: if an entry contains no ``')'``.
        """
        options_text = prompt.split('Options: ')[-1]
        options = {}
        for raw_option in options_text.split(' , '):
            option = raw_option.strip("[]' ")
            # Split on the first ')' only, so option text that itself contains
            # parentheses is kept intact.
            parts = option.split(')', 1)
            letter = parts[0].lower().strip()
            # The text is lower-cased first, so the suffixes removed below must be
            # lower-case too; the capitalised patterns could never match.
            content = (
                parts[1]
                .lower()
                .strip('.')
                .replace('. which option is correct?', '')
                .replace('. which one is correct?', '')
                .strip()
            )
            options[letter] = content
        return options
- # ==== TheoremQA ====
def compare_two_numbers(p, gt):
    """Return whether prediction ``p`` matches the ground truth ``gt``.

    A float ``gt`` is compared with ``within_eps``. Any other ``gt`` is compared
    against ``round(p)``. Non-numeric predictions (bool, str, list, tuple,
    complex, dict) and non-finite predictions against a non-float ``gt`` return
    False.

    Raises:
        ValueError: if ``p`` is of any other, unexpected type (e.g. ``None``).
    """
    # bool must be rejected before the int check: bool is a subclass of int, so
    # the previous ordering let True/False through as 1/0.
    if isinstance(p, bool):
        return False
    if isinstance(p, (int, float)):
        pass
    elif isinstance(p, (list, str, tuple, complex, dict)):
        return False
    else:
        raise ValueError(p)

    if isinstance(gt, float):
        return within_eps(pred=p, gt=gt)
    try:
        return round(p) == gt
    except (ValueError, OverflowError):
        # round() raises ValueError for NaN and OverflowError for +/-inf.
        return False
def compare_two_list(pred, gt):
    """Return whether list ``pred`` matches ``gt`` element-wise, ignoring order.

    ``pred`` must be a list of ints/floats with the same length as ``gt``.
    Both sequences are sorted, then paired elements are compared with
    ``compare_two_numbers``.
    """
    if not isinstance(pred, list):
        return False
    if len(pred) != len(gt):
        return False
    if any(not isinstance(x, (int, float)) for x in pred):
        return False
    # Generator (not list) lets all() stop at the first mismatching pair.
    return all(
        compare_two_numbers(p, g) for p, g in zip(sorted(pred), sorted(gt))
    )
def within_eps(pred: float, gt: float, rel_tol: float = 0.04) -> bool:
    """Return whether ``pred`` lies within ``rel_tol * abs(gt)`` of ``gt`` (inclusive).

    Args:
        pred: Predicted value.
        gt: Ground-truth value; a ``gt`` of 0 requires an exact match.
        rel_tol: Relative tolerance. Defaults to 0.04, the previously hard-coded
            value.
    """
    eps = abs(gt) * rel_tol
    return gt - eps <= pred <= gt + eps
def parse_number_list(s: str):
    """Evaluate ``s`` as a Python literal with ``ast.literal_eval`` and return it.

    No validation is performed: any literal (not only lists) is returned, and
    malformed input propagates the exception raised by ``ast.literal_eval``.
    """
    return ast.literal_eval(s)
def is_number(string):
    """Return whether ``string`` is a signed decimal number.

    The integer part may use comma thousands separators (``'1,234.5'``) or be a
    plain digit run (``'1234'``); an optional fractional part follows.
    """
    grouped_or_plain = r'^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$'
    return re.match(grouped_or_plain, string) is not None
def is_scientific_number(string):
    """Return whether ``string`` looks like ``<mantissa>e<exponent>``.

    Only a lower-case ``e`` and an optional ``-`` on the exponent are accepted.
    """
    mantissa_exponent = r'^[-+]?\d+(\.\d+)?e[-]?\d+$'
    return re.match(mantissa_exponent, string) is not None
def contain_num_and_str(string):
    """Return whether ``string`` contains at least one ASCII letter and one digit."""
    has_letter = re.search(r'[a-zA-Z]', string) is not None
    if not has_letter:
        return False
    return re.search(r'[0-9]', string) is not None
class TheoremqaTask(Task):
    """TheoremQA task whose answers are numbers, number lists, booleans or options.

    Answer normalisation and grading follow the reference TheoremQA evaluation
    script (see the links in ``extract_answer`` and ``success``).
    """

    task_name = 'reasoning'

    def __init__(self, id: str, prompt: str, reference: str, **kwargs):
        super().__init__(**kwargs)
        self._id = id
        self._prompt = (
            'Answer the following question with a number, a list of numbers or True or False. '
            + prompt.strip()
        )
        self._reference = reference
        # Selects the comparison used in success(); unrecognised values grade as incorrect.
        self._answer_type = kwargs.get('answer_type')

    @staticmethod
    def _to_plain_number(value: Any) -> Any:
        """Map complex -> real part and sympy Rational -> float; return others unchanged."""
        if isinstance(value, complex):
            return value.real
        if isinstance(value, Rational):
            return float(value)
        return value

    def extract_answer(self, solution: str) -> Optional[Any]:
        """Normalise a free-form answer and evaluate it into a Python value.

        Returns:
            A number, list, bool or option string such as ``'(a)'``, or None when
            the normalised text cannot be evaluated (a warning is logged).
        """
        prediction = solution
        # Following the preprocessing steps from TheoremQA
        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170

        # Preprocessing the string [Stage 1]
        if not isinstance(prediction, str):
            prediction = str(prediction) if prediction is not None else '0'

        # Keep only the right-hand side of '=' / '≈', and drop markup characters.
        if '=' in prediction:
            prediction = prediction.split('=')[-1].strip()
        if '≈' in prediction:
            prediction = prediction.split('≈')[-1].strip()
        for token in ('`', '$', '°'):
            prediction = prediction.replace(token, '')

        # Detect the boolean keyword in the generation
        if prediction in ['true', 'yes', 'false', 'no']:
            prediction = 'True' if prediction in ('true', 'yes') else 'False'
        if 'True' in prediction or 'False' in prediction:
            prediction = 'True' if 'True' in prediction else 'False'

        # Detect the approximation keyword
        if 'approximately' in prediction:
            prediction = prediction.replace('approximately', '').strip()
        if ' or ' in prediction:
            prediction = prediction.split(' or ')[0]

        # Drop the units before and after the number
        if re.match(r'[-+]?(?:[\d,]*\.*\d+) [^0-9 ]+$', prediction):
            prediction = re.search(
                r'([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$', prediction
            ).group(1)
        if re.match(r'[^0-9 ]+ [-+]?(?:[\d,]*\.*\d+)$', prediction):
            prediction = re.search(
                r'[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$', prediction
            ).group(1)
        if re.match(r'[-+]?(?:[\d,]*\.*\d+)[^\d]{1,2}$', prediction):
            prediction = re.search(
                r'([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$', prediction
            ).group(1)
        if re.match(r'[^-+\d]{1,2}(?:[\d,]*\.*\d+)$', prediction):
            prediction = re.search(
                r'[^-+\d]{1,2}((?:[\d,]*\.*\d+))$', prediction
            ).group(1)

        # Preprocessing the number [Stage 2]
        if '10^' in prediction:
            # Evaluated by eval() below; requires `math` to be imported at module level.
            prediction = re.sub(r'10\^(-?\d+)', r'math.pow(10, \1)', prediction)
        if ' x ' in prediction:
            prediction = prediction.replace(' x ', '*')
        if ' × ' in prediction:
            prediction = prediction.replace(' × ', '*')
        if is_number(prediction):
            prediction = prediction.replace(',', '')

        # Preprocessing the option [Stage 3]
        # Letters are checked in order against the *current* prediction, exactly
        # like the former four copy-pasted if-blocks.
        for letter in 'abcd':
            if (
                f'{letter})' in prediction
                or f'{letter} )' in prediction
                or prediction.lower().strip() == letter
            ):
                prediction = f'({letter})'
        if any(f'({letter})' in prediction for letter in 'abcd'):
            # Quote the option so eval() yields the string '(x)'.
            prediction = '"' + re.search(r'\([a-d]\)', prediction).group(0) + '"'

        # If the prediction is empty, use dummy '0'
        if not prediction:
            prediction = '0'

        # Converting the string answer to a number/list/bool/option.
        # SECURITY: eval() executes model-generated (untrusted) text with this
        # module's globals in scope. This mirrors the TheoremQA reference script;
        # consider a restricted/sandboxed evaluator before exposing it further.
        try:
            prediction = eval(prediction)
        except Exception:
            LOGGER.warning(
                '[TASK] Failed to convert the answer: %s\n%s',
                prediction,
                traceback.format_exc(),
            )
            return None  # failed to convert the answer

        # Performing common type conversion. Sequences (including empty ones,
        # which previously raised IndexError on prediction[0]) become lists with
        # each element normalised; scalars are normalised directly.
        if isinstance(prediction, np.ndarray):
            prediction = prediction.tolist()
        if isinstance(prediction, (set, tuple, list)):
            return [self._to_plain_number(item) for item in prediction]
        return self._to_plain_number(prediction)

    def success(self, solution: str) -> bool:
        """Return whether ``solution`` matches the reference under ``answer_type``.

        Predictions that fail to parse, have an unsupported type, or come with an
        unrecognised ``answer_type`` are graded False. Previously these paths left
        the result variable unbound.
        """
        # Follow the implementation from TheoremQA
        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1
        prediction = self.extract_answer(solution)
        LOGGER.info('TheoremQA Parsed Prediction: %s', prediction)
        answer_type = self._answer_type
        # NOTE(review): self.reference is presumably a Task property exposing
        # self._reference; it is not defined in this class.
        gt = self.extract_answer(self.reference)

        if not isinstance(prediction, (str, int, float, list)):
            return False
        if answer_type in ['bool', 'option', 'Option']:
            return bool(prediction == f'({gt})' or prediction == gt)
        if answer_type in ['integer', 'float']:
            return bool(compare_two_numbers(prediction, gt))
        if answer_type in ['list of integer', 'list of float']:
            return bool(compare_two_list(prediction, gt))
        return False
|