graph_utils.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import os
  2. import string
  3. from pathlib import Path
  4. from typing import Dict
  5. import pynini
  6. from fun_text_processing.inverse_text_normalization.zh.utils import get_abs_path
  7. from pynini import Far
  8. from pynini.examples import plurals
  9. from pynini.export import export
  10. from pynini.lib import byte, pynutil, utf8
  11. DAMO_CHAR = utf8.VALID_UTF8_CHAR
  12. DAMO_DIGIT = byte.DIGIT
  13. DAMO_LOWER = pynini.union(*string.ascii_lowercase).optimize()
  14. DAMO_UPPER = pynini.union(*string.ascii_uppercase).optimize()
  15. DAMO_ALPHA = pynini.union(DAMO_LOWER, DAMO_UPPER).optimize()
  16. DAMO_ALNUM = pynini.union(DAMO_DIGIT, DAMO_ALPHA).optimize()
  17. DAMO_HEX = pynini.union(*string.hexdigits).optimize()
  18. DAMO_NON_BREAKING_SPACE = u"\u00A0"
  19. DAMO_SPACE = " "
  20. DAMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", u"\u00A0").optimize()
  21. DAMO_NOT_SPACE = pynini.difference(DAMO_CHAR, DAMO_WHITE_SPACE).optimize()
  22. DAMO_NOT_QUOTE = pynini.difference(DAMO_CHAR, r'"').optimize()
  23. DAMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
  24. DAMO_GRAPH = pynini.union(DAMO_ALNUM, DAMO_PUNCT).optimize()
  25. DAMO_SIGMA = pynini.closure(DAMO_CHAR)
  26. delete_space = pynutil.delete(pynini.closure(DAMO_WHITE_SPACE))
  27. delete_zero_or_one_space = pynutil.delete(pynini.closure(DAMO_WHITE_SPACE, 0, 1))
  28. insert_space = pynutil.insert(" ")
  29. delete_extra_space = pynini.cross(pynini.closure(DAMO_WHITE_SPACE, 1), " ")
  30. delete_preserve_order = pynini.closure(
  31. pynutil.delete(" preserve_order: true")
  32. | (pynutil.delete(" field_order: \"") + DAMO_NOT_QUOTE + pynutil.delete("\""))
  33. )
  34. suppletive = pynini.string_file(get_abs_path("data/suppletive.tsv"))
  35. # _v = pynini.union("a", "e", "i", "o", "u")
  36. _c = pynini.union(
  37. "b", "c", "d", "f", "g", "h", "j", "k", "l", "m", "n", "p", "q", "r", "s", "t", "v", "w", "x", "y", "z"
  38. )
  39. _ies = DAMO_SIGMA + _c + pynini.cross("y", "ies")
  40. _es = DAMO_SIGMA + pynini.union("s", "sh", "ch", "x", "z") + pynutil.insert("es")
  41. _s = DAMO_SIGMA + pynutil.insert("s")
  42. graph_plural = plurals._priority_union(
  43. suppletive, plurals._priority_union(_ies, plurals._priority_union(_es, _s, DAMO_SIGMA), DAMO_SIGMA), DAMO_SIGMA
  44. ).optimize()
  45. SINGULAR_TO_PLURAL = graph_plural
  46. PLURAL_TO_SINGULAR = pynini.invert(graph_plural)
  47. TO_LOWER = pynini.union(*[pynini.cross(x, y) for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)])
  48. TO_UPPER = pynini.invert(TO_LOWER)
  49. MIN_NEG_WEIGHT = -0.0001
  50. MIN_POS_WEIGHT = 0.0001
  51. def generator_main(file_name: str, graphs: Dict[str, 'pynini.FstLike']):
  52. """
  53. Exports graph as OpenFst finite state archive (FAR) file with given file name and rule name.
  54. Args:
  55. file_name: exported file name
  56. graphs: Mapping of a rule name and Pynini WFST graph to be exported
  57. """
  58. exporter = export.Exporter(file_name)
  59. for rule, graph in graphs.items():
  60. exporter[rule] = graph.optimize()
  61. exporter.close()
  62. print(f'Created {file_name}')
  63. def get_plurals(fst):
  64. """
  65. Given singular returns plurals
  66. Args:
  67. fst: Fst
  68. Returns plurals to given singular forms
  69. """
  70. return SINGULAR_TO_PLURAL @ fst
  71. def get_singulars(fst):
  72. """
  73. Given plural returns singulars
  74. Args:
  75. fst: Fst
  76. Returns singulars to given plural forms
  77. """
  78. return PLURAL_TO_SINGULAR @ fst
  79. def convert_space(fst) -> 'pynini.FstLike':
  80. """
  81. Converts space to nonbreaking space.
  82. Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty"
  83. This is making transducer significantly slower, so only use when there could be potential spaces within quotes, otherwise leave it.
  84. Args:
  85. fst: input fst
  86. Returns output fst where breaking spaces are converted to non breaking spaces
  87. """
  88. return fst @ pynini.cdrewrite(pynini.cross(DAMO_SPACE, DAMO_NON_BREAKING_SPACE), "", "", DAMO_SIGMA)
  89. class GraphFst:
  90. """
  91. Base class for all grammar fsts.
  92. Args:
  93. name: name of grammar class
  94. kind: either 'classify' or 'verbalize'
  95. deterministic: if True will provide a single transduction option,
  96. for False multiple transduction are generated (used for audio-based normalization)
  97. """
  98. def __init__(self, name: str, kind: str, deterministic: bool = True):
  99. self.name = name
  100. self.kind = kind
  101. self._fst = None
  102. self.deterministic = deterministic
  103. self.far_path = Path(os.path.dirname(__file__) + '/grammars/' + kind + '/' + name + '.far')
  104. if self.far_exist():
  105. self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()
  106. def far_exist(self) -> bool:
  107. """
  108. Returns true if FAR can be loaded
  109. """
  110. return self.far_path.exists()
  111. @property
  112. def fst(self) -> 'pynini.FstLike':
  113. return self._fst
  114. @fst.setter
  115. def fst(self, fst):
  116. self._fst = fst
  117. def add_tokens(self, fst) -> 'pynini.FstLike':
  118. """
  119. Wraps class name around to given fst
  120. Args:
  121. fst: input fst
  122. Returns:
  123. Fst: fst
  124. """
  125. return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")
  126. def delete_tokens(self, fst) -> 'pynini.FstLike':
  127. """
  128. Deletes class name wrap around output of given fst
  129. Args:
  130. fst: input fst
  131. Returns:
  132. Fst: fst
  133. """
  134. res = (
  135. pynutil.delete(f"{self.name}")
  136. + delete_space
  137. + pynutil.delete("{")
  138. + delete_space
  139. + fst
  140. + delete_space
  141. + pynutil.delete("}")
  142. )
  143. return res @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", DAMO_SIGMA)