# run_evaluate.py
  1. from argparse import ArgumentParser
  2. from fun_text_processing.text_normalization.data_loader_utils import (
  3. evaluate,
  4. known_types,
  5. load_files,
  6. training_data_to_sentences,
  7. training_data_to_tokens,
  8. )
  9. from fun_text_processing.text_normalization.normalize import Normalizer
  10. '''
  11. Runs Evaluation on data in the format of : <semiotic class>\t<unnormalized text>\t<`self` if trivial class or normalized text>
  12. like the Google text normalization data https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
  13. '''
  14. def parse_args():
  15. parser = ArgumentParser()
  16. parser.add_argument("--input", help="input file path", type=str)
  17. parser.add_argument("--lang", help="language", choices=['en'], default="en", type=str)
  18. parser.add_argument(
  19. "--input_case", help="input capitalization", choices=["lower_cased", "cased"], default="cased", type=str
  20. )
  21. parser.add_argument(
  22. "--cat",
  23. dest="category",
  24. help="focus on class only (" + ", ".join(known_types) + ")",
  25. type=str,
  26. default=None,
  27. choices=known_types,
  28. )
  29. parser.add_argument("--filter", action='store_true', help="clean data for normalization purposes")
  30. return parser.parse_args()
  31. if __name__ == "__main__":
  32. # Example usage:
  33. # python run_evaluate.py --input=<INPUT> --cat=<CATEGORY> --filter
  34. args = parse_args()
  35. if args.lang == 'en':
  36. from fun_text_processing.text_normalization.en.clean_eval_data import filter_loaded_data
  37. file_path = args.input
  38. normalizer = Normalizer(input_case=args.input_case, lang=args.lang)
  39. print("Loading training data: " + file_path)
  40. training_data = load_files([file_path])
  41. if args.filter:
  42. training_data = filter_loaded_data(training_data)
  43. if args.category is None:
  44. print("Sentence level evaluation...")
  45. sentences_un_normalized, sentences_normalized, _ = training_data_to_sentences(training_data)
  46. print("- Data: " + str(len(sentences_normalized)) + " sentences")
  47. sentences_prediction = normalizer.normalize_list(sentences_un_normalized)
  48. print("- Normalized. Evaluating...")
  49. sentences_accuracy = evaluate(
  50. preds=sentences_prediction, labels=sentences_normalized, input=sentences_un_normalized
  51. )
  52. print("- Accuracy: " + str(sentences_accuracy))
  53. print("Token level evaluation...")
  54. tokens_per_type = training_data_to_tokens(training_data, category=args.category)
  55. token_accuracy = {}
  56. for token_type in tokens_per_type:
  57. print("- Token type: " + token_type)
  58. tokens_un_normalized, tokens_normalized = tokens_per_type[token_type]
  59. print(" - Data: " + str(len(tokens_normalized)) + " tokens")
  60. tokens_prediction = normalizer.normalize_list(tokens_un_normalized)
  61. print(" - Denormalized. Evaluating...")
  62. token_accuracy[token_type] = evaluate(
  63. preds=tokens_prediction, labels=tokens_normalized, input=tokens_un_normalized
  64. )
  65. print(" - Accuracy: " + str(token_accuracy[token_type]))
  66. token_count_per_type = {token_type: len(tokens_per_type[token_type][0]) for token_type in tokens_per_type}
  67. token_weighted_accuracy = [
  68. token_count_per_type[token_type] * accuracy for token_type, accuracy in token_accuracy.items()
  69. ]
  70. print("- Accuracy: " + str(sum(token_weighted_accuracy) / sum(token_count_per_type.values())))
  71. print(" - Total: " + str(sum(token_count_per_type.values())), '\n')
  72. print(" - Total: " + str(sum(token_count_per_type.values())), '\n')
  73. for token_type in token_accuracy:
  74. if token_type not in known_types:
  75. raise ValueError("Unexpected token type: " + token_type)
  76. if args.category is None:
  77. c1 = ['Class', 'sent level'] + known_types
  78. c2 = ['Num Tokens', len(sentences_normalized)] + [
  79. token_count_per_type[known_type] if known_type in tokens_per_type else '0' for known_type in known_types
  80. ]
  81. c3 = ['Normalization', sentences_accuracy] + [
  82. token_accuracy[known_type] if known_type in token_accuracy else '0' for known_type in known_types
  83. ]
  84. for i in range(len(c1)):
  85. print(f'{str(c1[i]):10s} | {str(c2[i]):10s} | {str(c3[i]):5s}')
  86. else:
  87. print(f'numbers\t{token_count_per_type[args.category]}')
  88. print(f'Normalization\t{token_accuracy[args.category]}')