post_processing.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import os
  2. import pynini
  3. from fun_text_processing.text_normalization.en.graph_utils import (
  4. MIN_NEG_WEIGHT,
  5. DAMO_ALPHA,
  6. DAMO_CHAR,
  7. DAMO_SIGMA,
  8. DAMO_SPACE,
  9. generator_main,
  10. )
  11. from fun_text_processing.text_normalization.en.taggers.punctuation import PunctuationFst
  12. from pynini.lib import pynutil
  13. import logging
  14. class PostProcessingFst:
  15. """
  16. Finite state transducer that post-processing an entire sentence after verbalization is complete, e.g.
  17. removes extra spaces around punctuation marks " ( one hundred and twenty three ) " -> "(one hundred and twenty three)"
  18. Args:
  19. cache_dir: path to a dir with .far grammar file. Set to None to avoid using cache.
  20. overwrite_cache: set to True to overwrite .far files
  21. """
  22. def __init__(self, cache_dir: str = None, overwrite_cache: bool = False):
  23. far_file = None
  24. if cache_dir is not None and cache_dir != "None":
  25. os.makedirs(cache_dir, exist_ok=True)
  26. far_file = os.path.join(cache_dir, "en_tn_post_processing.far")
  27. if not overwrite_cache and far_file and os.path.exists(far_file):
  28. self.fst = pynini.Far(far_file, mode="r")["post_process_graph"]
  29. logging.info(f'Post processing graph was restored from {far_file}.')
  30. else:
  31. self.set_punct_dict()
  32. self.fst = self.get_punct_postprocess_graph()
  33. if far_file:
  34. generator_main(far_file, {"post_process_graph": self.fst})
  35. def set_punct_dict(self):
  36. self.punct_marks = {
  37. "'": [
  38. "'",
  39. '´',
  40. 'ʹ',
  41. 'ʻ',
  42. 'ʼ',
  43. 'ʽ',
  44. 'ʾ',
  45. 'ˈ',
  46. 'ˊ',
  47. 'ˋ',
  48. '˴',
  49. 'ʹ',
  50. '΄',
  51. '՚',
  52. '՝',
  53. 'י',
  54. '׳',
  55. 'ߴ',
  56. 'ߵ',
  57. 'ᑊ',
  58. 'ᛌ',
  59. '᾽',
  60. '᾿',
  61. '`',
  62. '´',
  63. '῾',
  64. '‘',
  65. '’',
  66. '‛',
  67. '′',
  68. '‵',
  69. 'ꞌ',
  70. ''',
  71. '`',
  72. '𖽑',
  73. '𖽒',
  74. ],
  75. }
  76. def get_punct_postprocess_graph(self):
  77. """
  78. Returns graph to post process punctuation marks.
  79. {``} quotes are converted to {"}. Note, if there are spaces around single quote {'}, they will be kept.
  80. By default, a space is added after a punctuation mark, and spaces are removed before punctuation marks.
  81. """
  82. punct_marks_all = PunctuationFst().punct_marks
  83. # no_space_before_punct assume no space before them
  84. quotes = ["'", "\"", "``", "«"]
  85. dashes = ["-", "—"]
  86. brackets = ["<", "{", "("]
  87. open_close_single_quotes = [
  88. ("`", "`"),
  89. ]
  90. open_close_double_quotes = [('"', '"'), ("``", "``"), ("“", "”")]
  91. open_close_symbols = open_close_single_quotes + open_close_double_quotes
  92. allow_space_before_punct = ["&"] + quotes + dashes + brackets + [k[0] for k in open_close_symbols]
  93. no_space_before_punct = [m for m in punct_marks_all if m not in allow_space_before_punct]
  94. no_space_before_punct = pynini.union(*no_space_before_punct)
  95. no_space_after_punct = pynini.union(*brackets)
  96. delete_space = pynutil.delete(" ")
  97. delete_space_optional = pynini.closure(delete_space, 0, 1)
  98. # non_punct allows space
  99. # delete space before no_space_before_punct marks, if present
  100. non_punct = pynini.difference(DAMO_CHAR, no_space_before_punct).optimize()
  101. graph = (
  102. pynini.closure(non_punct)
  103. + pynini.closure(
  104. no_space_before_punct | pynutil.add_weight(delete_space + no_space_before_punct, MIN_NEG_WEIGHT)
  105. )
  106. + pynini.closure(non_punct)
  107. )
  108. graph = pynini.closure(graph).optimize()
  109. graph = pynini.compose(
  110. graph, pynini.cdrewrite(pynini.cross("``", '"'), "", "", DAMO_SIGMA).optimize()
  111. ).optimize()
  112. # remove space after no_space_after_punct (even if there are no matching closing brackets)
  113. no_space_after_punct = pynini.cdrewrite(delete_space, no_space_after_punct, DAMO_SIGMA, DAMO_SIGMA).optimize()
  114. graph = pynini.compose(graph, no_space_after_punct).optimize()
  115. # remove space around text in quotes
  116. single_quote = pynutil.add_weight(pynini.accep("`"), MIN_NEG_WEIGHT)
  117. double_quotes = pynutil.add_weight(pynini.accep('"'), MIN_NEG_WEIGHT)
  118. quotes_graph = (
  119. single_quote + delete_space_optional + DAMO_ALPHA + DAMO_SIGMA + delete_space_optional + single_quote
  120. ).optimize()
  121. # this is to make sure multiple quotes are tagged from right to left without skipping any quotes in the left
  122. not_alpha = pynini.difference(DAMO_CHAR, DAMO_ALPHA).optimize() | pynutil.add_weight(
  123. DAMO_SPACE, MIN_NEG_WEIGHT
  124. )
  125. end = pynini.closure(pynutil.add_weight(not_alpha, MIN_NEG_WEIGHT))
  126. quotes_graph |= (
  127. double_quotes
  128. + delete_space_optional
  129. + DAMO_ALPHA
  130. + DAMO_SIGMA
  131. + delete_space_optional
  132. + double_quotes
  133. + end
  134. )
  135. quotes_graph = pynutil.add_weight(quotes_graph, MIN_NEG_WEIGHT)
  136. quotes_graph = DAMO_SIGMA + pynini.closure(DAMO_SIGMA + quotes_graph + DAMO_SIGMA)
  137. graph = pynini.compose(graph, quotes_graph).optimize()
  138. # remove space between a word and a single quote followed by s
  139. remove_space_around_single_quote = pynini.cdrewrite(
  140. delete_space_optional + pynini.union(*self.punct_marks["'"]) + delete_space,
  141. DAMO_ALPHA,
  142. pynini.union("s ", "s[EOS]"),
  143. DAMO_SIGMA,
  144. )
  145. graph = pynini.compose(graph, remove_space_around_single_quote).optimize()
  146. return graph