# graph_utils.py (5.2 KB)
  1. import os
  2. import string
  3. from pathlib import Path
  4. from typing import Dict
  5. import pynini
  6. from fun_text_processing.inverse_text_normalization.fr.utils import get_abs_path
  7. from pynini import Far
  8. from pynini.examples import plurals
  9. from pynini.export import export
  10. from pynini.lib import byte, pynutil, utf8
  11. DAMO_CHAR = utf8.VALID_UTF8_CHAR
  12. DAMO_DIGIT = byte.DIGIT
  13. DAMO_LOWER = pynini.union(*string.ascii_lowercase).optimize()
  14. DAMO_UPPER = pynini.union(*string.ascii_uppercase).optimize()
  15. DAMO_ALPHA = pynini.union(DAMO_LOWER, DAMO_UPPER).optimize()
  16. DAMO_ALNUM = pynini.union(DAMO_DIGIT, DAMO_ALPHA).optimize()
  17. DAMO_HEX = pynini.union(*string.hexdigits).optimize()
  18. DAMO_NON_BREAKING_SPACE = u"\u00A0"
  19. DAMO_SPACE = " "
  20. DAMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", u"\u00A0").optimize()
  21. DAMO_NOT_SPACE = pynini.difference(DAMO_CHAR, DAMO_WHITE_SPACE).optimize()
  22. DAMO_NOT_QUOTE = pynini.difference(DAMO_CHAR, r'"').optimize()
  23. DAMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
  24. DAMO_GRAPH = pynini.union(DAMO_ALNUM, DAMO_PUNCT).optimize()
  25. DAMO_SIGMA = pynini.closure(DAMO_CHAR)
  26. delete_space = pynutil.delete(pynini.closure(DAMO_WHITE_SPACE))
  27. insert_space = pynutil.insert(" ")
  28. delete_extra_space = pynini.cross(pynini.closure(DAMO_WHITE_SPACE, 1), " ")
  29. # French frequently compounds numbers with hyphen.
  30. delete_hyphen = pynutil.delete(pynini.closure("-", 0, 1))
  31. insert_hyphen = pynutil.insert("-")
  32. suppletive = pynini.string_file(get_abs_path("data/suppletive.tsv"))
  33. _s = DAMO_SIGMA + pynutil.insert("s")
  34. _x = DAMO_SIGMA + pynini.string_map([("eau"), ("eu"), ("ou")]) + pynutil.insert("x")
  35. _aux = DAMO_SIGMA + pynini.string_map([("al", "aux"), ("ail", "aux")])
  36. graph_plural = plurals._priority_union(
  37. suppletive, plurals._priority_union(_s, pynini.union(_x, _aux), DAMO_SIGMA), DAMO_SIGMA
  38. ).optimize()
  39. SINGULAR_TO_PLURAL = graph_plural
  40. PLURAL_TO_SINGULAR = pynini.invert(graph_plural)
  41. TO_LOWER = pynini.union(*[pynini.cross(x, y) for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)])
  42. TO_UPPER = pynini.invert(TO_LOWER)
  43. def generator_main(file_name: str, graphs: Dict[str, pynini.FstLike]):
  44. """
  45. Exports graph as OpenFst finite state archive (FAR) file with given file name and rule name.
  46. Args:
  47. file_name: exported file name
  48. graphs: Mapping of a rule name and Pynini WFST graph to be exported
  49. """
  50. exporter = export.Exporter(file_name)
  51. for rule, graph in graphs.items():
  52. exporter[rule] = graph.optimize()
  53. exporter.close()
  54. print(f'Created {file_name}')
  55. def get_plurals(fst):
  56. """
  57. Given singular returns plurals
  58. Args:
  59. fst: Fst
  60. Returns plurals to given singular forms
  61. """
  62. return SINGULAR_TO_PLURAL @ fst
  63. def get_singulars(fst):
  64. """
  65. Given plural returns singulars
  66. Args:
  67. fst: Fst
  68. Returns singulars to given plural forms
  69. """
  70. return PLURAL_TO_SINGULAR @ fst
  71. def convert_space(fst) -> 'pynini.FstLike':
  72. """
  73. Converts space to nonbreaking space.
  74. Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty"
  75. This is making transducer significantly slower, so only use when there could be potential spaces within quotes, otherwise leave it.
  76. Args:
  77. fst: input fst
  78. Returns output fst where breaking spaces are converted to non breaking spaces
  79. """
  80. return fst @ pynini.cdrewrite(pynini.cross(DAMO_SPACE, DAMO_NON_BREAKING_SPACE), "", "", DAMO_SIGMA)
  81. class GraphFst:
  82. """
  83. Base class for all grammar fsts.
  84. Args:
  85. name: name of grammar class
  86. kind: either 'classify' or 'verbalize'
  87. deterministic: if True will provide a single transduction option,
  88. for False multiple transduction are generated (used for audio-based normalization)
  89. """
  90. def __init__(self, name: str, kind: str, deterministic: bool = True):
  91. self.name = name
  92. self.kind = kind
  93. self._fst = None
  94. self.deterministic = deterministic
  95. self.far_path = Path(os.path.dirname(__file__) + '/grammars/' + kind + '/' + name + '.far')
  96. if self.far_exist():
  97. self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()
  98. def far_exist(self) -> bool:
  99. """
  100. Returns true if FAR can be loaded
  101. """
  102. return self.far_path.exists()
  103. @property
  104. def fst(self) -> 'pynini.FstLike':
  105. return self._fst
  106. @fst.setter
  107. def fst(self, fst):
  108. self._fst = fst
  109. def add_tokens(self, fst) -> 'pynini.FstLike':
  110. """
  111. Wraps class name around to given fst
  112. Args:
  113. fst: input fst
  114. Returns:
  115. Fst: fst
  116. """
  117. return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")
  118. def delete_tokens(self, fst) -> 'pynini.FstLike':
  119. """
  120. Deletes class name wrap around output of given fst
  121. Args:
  122. fst: input fst
  123. Returns:
  124. Fst: fst
  125. """
  126. res = (
  127. pynutil.delete(f"{self.name}")
  128. + delete_space
  129. + pynutil.delete("{")
  130. + delete_space
  131. + fst
  132. + delete_space
  133. + pynutil.delete("}")
  134. )
  135. return res @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", DAMO_SIGMA)