|
|
@@ -131,11 +131,9 @@ class MultipleChoiceTask(Task):
|
|
|
|
|
|
|
|
|
def compare_two_numbers(p, gt):
|
|
|
- if isinstance(p, int) or isinstance(p, float):
|
|
|
+ if isinstance(p, (int, float)):
|
|
|
pass
|
|
|
- elif isinstance(p, list) or isinstance(p, bool) or isinstance(p, str):
|
|
|
- return False
|
|
|
- elif isinstance(p, tuple) or isinstance(p, complex) or isinstance(p, dict):
|
|
|
+ elif isinstance(p, (bool, complex, dict, list, str, tuple)):
|
|
|
return False
|
|
|
else:
|
|
|
raise ValueError(p)
|
|
|
@@ -227,8 +225,8 @@ class TheoremqaTask(Task):
|
|
|
prediction = prediction.replace('°', '')
|
|
|
|
|
|
# Detect the boolean keyword in the generation
|
|
|
- if prediction in ['true', 'yes', 'false', 'no']:
|
|
|
- if prediction == 'true' or prediction == 'yes':
|
|
|
+ if prediction in ('true', 'yes', 'false', 'no'):
|
|
|
+ if prediction in ('true', 'yes'):
|
|
|
prediction = 'True'
|
|
|
else:
|
|
|
prediction = 'False'
|
|
|
@@ -342,7 +340,7 @@ class TheoremqaTask(Task):
|
|
|
answer_type = self._answer_type
|
|
|
gt = self.extract_answer(self.reference)
|
|
|
|
|
|
- if isinstance(prediction, (str, int, float)) or isinstance(prediction, list):
|
|
|
+ if isinstance(prediction, (str, int, float, list)):
|
|
|
# Comparing prediction against the reference
|
|
|
if answer_type in ['bool', 'option', 'Option']:
|
|
|
cur_correct = int(prediction == f'({gt})') or int(prediction == gt)
|