scorer.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import re
  2. import string
  3. import warnings
  4. def normalize_number_str(number_str: str) -> float:
  5. # we replace these common units and commas to allow
  6. # conversion to float
  7. for char in ['$', '%', ',']:
  8. number_str = number_str.replace(char, '')
  9. try:
  10. return float(number_str)
  11. except ValueError:
  12. print(f'String {number_str} cannot be normalized to number str.')
  13. return float('inf')
  14. def split_string(
  15. s: str,
  16. char_list: list[str] = None,
  17. ) -> list[str]:
  18. if char_list is None:
  19. char_list = [',', ';']
  20. pattern = f"[{''.join(char_list)}]"
  21. return re.split(pattern, s)
  22. def question_scorer(
  23. model_answer: str,
  24. ground_truth: str,
  25. ) -> bool:
  26. def is_float(element: any) -> bool:
  27. try:
  28. float(element)
  29. return True
  30. except ValueError:
  31. return False
  32. # if gt is a number
  33. if is_float(ground_truth):
  34. print(f'Evaluating {model_answer} as a number.')
  35. normalized_answer = normalize_number_str(model_answer)
  36. return normalized_answer == float(ground_truth)
  37. # if gt is a list
  38. elif any(char in ground_truth for char in [',', ';']):
  39. print(f'Evaluating {model_answer} as a comma separated list.')
  40. # question with the fish: normalization removes punct
  41. gt_elems = split_string(ground_truth)
  42. ma_elems = split_string(model_answer)
  43. # check length is the same
  44. if len(gt_elems) != len(ma_elems):
  45. warnings.warn(
  46. 'Answer lists have different lengths, returning False.',
  47. UserWarning,
  48. stacklevel=2,
  49. )
  50. return False
  51. # compare each element as float or str
  52. comparisons = []
  53. for ma_elem, gt_elem in zip(ma_elems, gt_elems):
  54. if is_float(gt_elem):
  55. normalized_ma_elem = normalize_number_str(ma_elem)
  56. comparisons.append(normalized_ma_elem == float(gt_elem))
  57. else:
  58. # we do not remove punct since comparisons can include punct
  59. comparisons.append(
  60. normalize_str(ma_elem, remove_punct=False)
  61. == normalize_str(gt_elem, remove_punct=False)
  62. )
  63. return all(comparisons)
  64. # if gt is a str
  65. else:
  66. print(f'Evaluating {model_answer} as a string.')
  67. return normalize_str(model_answer) == normalize_str(ground_truth)
  68. def normalize_str(input_str, remove_punct=True) -> str:
  69. """Normalize a string by:
  70. - Removing all white spaces
  71. - Optionally removing punctuation (if remove_punct is True)
  72. - Converting to lowercase
  73. Parameters:
  74. - input_str: str, the string to normalize
  75. - remove_punct: bool, whether to remove punctuation (default: True)
  76. Returns:
  77. - str, the normalized string
  78. """
  79. # Remove all white spaces. Required e.g for seagull vs. sea gull
  80. no_spaces = re.sub(r'\s', '', input_str)
  81. # Remove punctuation, if specified.
  82. if remove_punct:
  83. translator = str.maketrans('', '', string.punctuation)
  84. return no_spaces.lower().translate(translator)
  85. else:
  86. return no_spaces.lower()