# inverse_normalize.py
  1. #!/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. from argparse import ArgumentParser
  4. from time import perf_counter
  5. from typing import List
  6. from fun_text_processing.text_normalization.data_loader_utils import load_file, write_file
  7. from fun_text_processing.text_normalization.normalize import Normalizer
  8. from fun_text_processing.text_normalization.token_parser import TokenParser
  9. class InverseNormalizer(Normalizer):
  10. """
  11. Inverse normalizer that converts text from spoken to written form. Useful for ASR postprocessing.
  12. Input is expected to have no punctuation outside of approstrophe (') and dash (-) and be lower cased.
  13. Args:
  14. lang: language specifying the ITN
  15. cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
  16. overwrite_cache: set to True to overwrite .far files
  17. """
  18. def __init__(self, lang: str = 'en', cache_dir: str = None, overwrite_cache: bool = False,
  19. enable_standalone_number: bool = True,
  20. enable_0_to_9: bool = True):
  21. if lang == 'en':
  22. from fun_text_processing.inverse_text_normalization.en.taggers.tokenize_and_classify import ClassifyFst
  23. from fun_text_processing.inverse_text_normalization.en.verbalizers.verbalize_final import (
  24. VerbalizeFinalFst,
  25. )
  26. elif lang == 'id':
  27. from fun_text_processing.inverse_text_normalization.id.taggers.tokenize_and_classify import ClassifyFst
  28. from fun_text_processing.inverse_text_normalization.id.verbalizers.verbalize_final import (
  29. VerbalizeFinalFst,
  30. )
  31. elif lang == 'ja':
  32. from fun_text_processing.inverse_text_normalization.ja.taggers.tokenize_and_classify import ClassifyFst
  33. from fun_text_processing.inverse_text_normalization.ja.verbalizers.verbalize_final import (
  34. VerbalizeFinalFst,
  35. )
  36. elif lang == 'es':
  37. from fun_text_processing.inverse_text_normalization.es.taggers.tokenize_and_classify import ClassifyFst
  38. from fun_text_processing.inverse_text_normalization.es.verbalizers.verbalize_final import (
  39. VerbalizeFinalFst,
  40. )
  41. elif lang == 'pt':
  42. from fun_text_processing.inverse_text_normalization.pt.taggers.tokenize_and_classify import ClassifyFst
  43. from fun_text_processing.inverse_text_normalization.pt.verbalizers.verbalize_final import (
  44. VerbalizeFinalFst,
  45. )
  46. elif lang == 'ru':
  47. from fun_text_processing.inverse_text_normalization.ru.taggers.tokenize_and_classify import ClassifyFst
  48. from fun_text_processing.inverse_text_normalization.ru.verbalizers.verbalize_final import (
  49. VerbalizeFinalFst,
  50. )
  51. elif lang == 'de':
  52. from fun_text_processing.inverse_text_normalization.de.taggers.tokenize_and_classify import ClassifyFst
  53. from fun_text_processing.inverse_text_normalization.de.verbalizers.verbalize_final import (
  54. VerbalizeFinalFst,
  55. )
  56. elif lang == 'fr':
  57. from fun_text_processing.inverse_text_normalization.fr.taggers.tokenize_and_classify import ClassifyFst
  58. from fun_text_processing.inverse_text_normalization.fr.verbalizers.verbalize_final import (
  59. VerbalizeFinalFst,
  60. )
  61. elif lang == 'vi':
  62. from fun_text_processing.inverse_text_normalization.vi.taggers.tokenize_and_classify import ClassifyFst
  63. from fun_text_processing.inverse_text_normalization.vi.verbalizers.verbalize_final import (
  64. VerbalizeFinalFst,
  65. )
  66. elif lang == 'ko':
  67. from fun_text_processing.inverse_text_normalization.ko.taggers.tokenize_and_classify import ClassifyFst
  68. from fun_text_processing.inverse_text_normalization.ko.verbalizers.verbalize_final import (
  69. VerbalizeFinalFst,
  70. )
  71. elif lang == 'zh':
  72. from fun_text_processing.inverse_text_normalization.zh.taggers.tokenize_and_classify import ClassifyFst
  73. from fun_text_processing.inverse_text_normalization.zh.verbalizers.verbalize_final import (
  74. VerbalizeFinalFst,
  75. )
  76. elif lang == 'tl':
  77. from fun_text_processing.inverse_text_normalization.tl.taggers.tokenize_and_classify import ClassifyFst
  78. from fun_text_processing.inverse_text_normalization.tl.verbalizers.verbalize_final import (
  79. VerbalizeFinalFst,
  80. )
  81. self.tagger = ClassifyFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache)
  82. self.verbalizer = VerbalizeFinalFst()
  83. self.parser = TokenParser()
  84. self.lang = lang
  85. self.convert_number = enable_standalone_number
  86. self.enable_0_to_9 = enable_0_to_9
  87. def inverse_normalize_list(self, texts: List[str], verbose=False) -> List[str]:
  88. """
  89. NeMo inverse text normalizer
  90. Args:
  91. texts: list of input strings
  92. verbose: whether to print intermediate meta information
  93. Returns converted list of input strings
  94. """
  95. # print(texts)
  96. return self.normalize_list(texts=texts, verbose=verbose)
  97. def inverse_normalize(self, text: str, verbose: bool) -> str:
  98. """
  99. Main function. Inverse normalizes tokens from spoken to written form
  100. e.g. twelve kilograms -> 12 kg
  101. Args:
  102. text: string that may include semiotic classes
  103. verbose: whether to print intermediate meta information
  104. Returns: written form
  105. """
  106. print(text)
  107. return self.normalize(text=text, verbose=verbose)
  108. def str2bool(s, default=False):
  109. s = s.lower()
  110. if s == 'true':
  111. return True
  112. elif s == 'false':
  113. return False
  114. else:
  115. return default
  116. def parse_args():
  117. parser = ArgumentParser()
  118. input = parser.add_mutually_exclusive_group()
  119. input.add_argument("--text", dest="input_string", help="input string", type=str)
  120. input.add_argument("--input_file", dest="input_file", help="input file path", type=str)
  121. parser.add_argument('--output_file', dest="output_file", help="output file path", type=str)
  122. parser.add_argument(
  123. "--language", help="language", choices=['en', 'id', 'ja', 'de', 'es', 'pt', 'ru', 'fr', 'vi', 'ko', 'zh', 'tl'], default="en", type=str
  124. )
  125. parser.add_argument("--verbose", help="print info for debugging", action='store_true')
  126. parser.add_argument("--overwrite_cache", help="set to True to re-create .far grammar files", action="store_true")
  127. parser.add_argument(
  128. "--cache_dir",
  129. help="path to a dir with .far grammar file. Set to None to avoid using cache",
  130. default=None,
  131. type=str,
  132. )
  133. parser.add_argument('--enable_standalone_number', type=str,
  134. default='True',
  135. help='enable standalone number')
  136. parser.add_argument('--enable_0_to_9', type=str,
  137. default='True',
  138. help='enable convert number 0 to 9')
  139. return parser.parse_args()
  140. if __name__ == "__main__":
  141. args = parse_args()
  142. start_time = perf_counter()
  143. if args.language == 'ja':
  144. inverse_normalizer = InverseNormalizer(lang=args.language, cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache,
  145. enable_standalone_number=str2bool(args.enable_standalone_number),
  146. enable_0_to_9=str2bool(args.enable_0_to_9))
  147. else:
  148. inverse_normalizer = InverseNormalizer(
  149. lang=args.language, cache_dir=args.cache_dir, overwrite_cache=args.overwrite_cache
  150. )
  151. print(f'Time to generate graph: {round(perf_counter() - start_time, 2)} sec')
  152. if args.input_string:
  153. print(inverse_normalizer.inverse_normalize(args.input_string, verbose=args.verbose))
  154. elif args.input_file:
  155. print("Loading data: " + args.input_file)
  156. data = load_file(args.input_file)
  157. print("- Data: " + str(len(data)) + " sentences")
  158. prediction = inverse_normalizer.inverse_normalize_list(data, verbose=args.verbose)
  159. if args.output_file:
  160. write_file(args.output_file, prediction)
  161. print(f"- Denormalized. Writing out to {args.output_file}")
  162. else:
  163. print(prediction)