# run_evaluate.py
# Copyright NeMo (https://github.com/NVIDIA/NeMo). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser

from fun_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer
from fun_text_processing.text_normalization.data_loader_utils import (
    evaluate,
    known_types,
    load_files,
    training_data_to_sentences,
    training_data_to_tokens,
)
'''
Runs evaluation on data in the format: <semiotic class>\t<unnormalized text>\t<`self` if trivial class, else normalized text>
as in the Google text normalization data https://www.kaggle.com/richardwilliamsproat/text-normalization-for-english-russian-and-polish
'''
  27. def parse_args():
  28. parser = ArgumentParser()
  29. parser.add_argument("--input", help="input file path", type=str)
  30. parser.add_argument(
  31. "--lang", help="language", choices=['en', 'id', 'ja', 'de', 'es', 'pt', 'ru', 'fr', 'vi', 'ko', 'zh', 'fil'], default="en", type=str
  32. )
  33. parser.add_argument(
  34. "--cat",
  35. dest="category",
  36. help="focus on class only (" + ", ".join(known_types) + ")",
  37. type=str,
  38. default=None,
  39. choices=known_types,
  40. )
  41. parser.add_argument("--filter", action='store_true', help="clean data for inverse normalization purposes")
  42. return parser.parse_args()
  43. if __name__ == "__main__":
  44. # Example usage:
  45. # python run_evaluate.py --input=<INPUT> --cat=<CATEGORY> --filter
  46. args = parse_args()
  47. if args.lang == 'en':
  48. from fun_text_processing.inverse_text_normalization.en.clean_eval_data import filter_loaded_data
  49. file_path = args.input
  50. inverse_normalizer = InverseNormalizer()
  51. print("Loading training data: " + file_path)
  52. training_data = load_files([file_path])
  53. if args.filter:
  54. training_data = filter_loaded_data(training_data)
  55. if args.category is None:
  56. print("Sentence level evaluation...")
  57. sentences_un_normalized, sentences_normalized, _ = training_data_to_sentences(training_data)
  58. print("- Data: " + str(len(sentences_normalized)) + " sentences")
  59. sentences_prediction = inverse_normalizer.inverse_normalize_list(sentences_normalized)
  60. print("- Denormalized. Evaluating...")
  61. sentences_accuracy = evaluate(
  62. preds=sentences_prediction, labels=sentences_un_normalized, input=sentences_normalized
  63. )
  64. print("- Accuracy: " + str(sentences_accuracy))
  65. print("Token level evaluation...")
  66. tokens_per_type = training_data_to_tokens(training_data, category=args.category)
  67. token_accuracy = {}
  68. for token_type in tokens_per_type:
  69. print("- Token type: " + token_type)
  70. tokens_un_normalized, tokens_normalized = tokens_per_type[token_type]
  71. print(" - Data: " + str(len(tokens_normalized)) + " tokens")
  72. tokens_prediction = inverse_normalizer.inverse_normalize_list(tokens_normalized)
  73. print(" - Denormalized. Evaluating...")
  74. token_accuracy[token_type] = evaluate(tokens_prediction, tokens_un_normalized, input=tokens_normalized)
  75. print(" - Accuracy: " + str(token_accuracy[token_type]))
  76. token_count_per_type = {token_type: len(tokens_per_type[token_type][0]) for token_type in tokens_per_type}
  77. token_weighted_accuracy = [
  78. token_count_per_type[token_type] * accuracy for token_type, accuracy in token_accuracy.items()
  79. ]
  80. print("- Accuracy: " + str(sum(token_weighted_accuracy) / sum(token_count_per_type.values())))
  81. print(" - Total: " + str(sum(token_count_per_type.values())), '\n')
  82. for token_type in token_accuracy:
  83. if token_type not in known_types:
  84. raise ValueError("Unexpected token type: " + token_type)
  85. if args.category is None:
  86. c1 = ['Class', 'sent level'] + known_types
  87. c2 = ['Num Tokens', len(sentences_normalized)] + [
  88. token_count_per_type[known_type] if known_type in tokens_per_type else '0' for known_type in known_types
  89. ]
  90. c3 = ["Denormalization", sentences_accuracy] + [
  91. token_accuracy[known_type] if known_type in token_accuracy else '0' for known_type in known_types
  92. ]
  93. for i in range(len(c1)):
  94. print(f'{str(c1[i]):10s} | {str(c2[i]):10s} | {str(c3[i]):5s}')
  95. else:
  96. print(f'numbers\t{token_count_per_type[args.category]}')
  97. print(f'Denormalization\t{token_accuracy[args.category]}')