graph_utils.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright NeMo (https://github.com/NVIDIA/NeMo). All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import string
  16. from pathlib import Path
  17. from typing import Dict
  18. import pynini
  19. from pynini import Far
  20. from pynini.export import export
  21. from pynini.lib import byte, pynutil, utf8
  22. DAMO_CHAR = utf8.VALID_UTF8_CHAR
  23. DAMO_DIGIT = byte.DIGIT
  24. DAMO_LOWER = pynini.union(*string.ascii_lowercase).optimize()
  25. DAMO_UPPER = pynini.union(*string.ascii_uppercase).optimize()
  26. DAMO_ALPHA = pynini.union(DAMO_LOWER, DAMO_UPPER).optimize()
  27. DAMO_ALNUM = pynini.union(DAMO_DIGIT, DAMO_ALPHA).optimize()
  28. DAMO_HEX = pynini.union(*string.hexdigits).optimize()
  29. DAMO_NON_BREAKING_SPACE = "\u00A0"
  30. DAMO_SPACE = " "
  31. DAMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", "\u00A0").optimize()
  32. DAMO_NOT_SPACE = pynini.difference(DAMO_CHAR, DAMO_WHITE_SPACE).optimize()
  33. DAMO_NOT_QUOTE = pynini.difference(DAMO_CHAR, r'"').optimize()
  34. DAMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
  35. DAMO_GRAPH = pynini.union(DAMO_ALNUM, DAMO_PUNCT).optimize()
  36. DAMO_SIGMA = pynini.closure(DAMO_CHAR)
  37. delete_space = pynutil.delete(pynini.closure(DAMO_WHITE_SPACE))
  38. insert_space = pynutil.insert(" ")
  39. delete_extra_space = pynini.cross(pynini.closure(DAMO_WHITE_SPACE, 1), " ")
  40. # French frequently compounds numbers with hyphen.
  41. delete_hyphen = pynutil.delete(pynini.closure("-", 0, 1))
  42. insert_hyphen = pynutil.insert("-")
  43. TO_LOWER = pynini.union(*[pynini.cross(x, y) for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)])
  44. TO_UPPER = pynini.invert(TO_LOWER)
  45. def generator_main(file_name: str, graphs: Dict[str, pynini.FstLike]):
  46. """
  47. Exports graph as OpenFst finite state archive (FAR) file with given file name and rule name.
  48. Args:
  49. file_name: exported file name
  50. graphs: Mapping of a rule name and Pynini WFST graph to be exported
  51. """
  52. exporter = export.Exporter(file_name)
  53. for rule, graph in graphs.items():
  54. exporter[rule] = graph.optimize()
  55. exporter.close()
  56. print(f"Created {file_name}")
  57. def convert_space(fst) -> "pynini.FstLike":
  58. """
  59. Converts space to nonbreaking space.
  60. Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty"
  61. This is making transducer significantly slower, so only use when there could be potential spaces within quotes, otherwise leave it.
  62. Args:
  63. fst: input fst
  64. Returns output fst where breaking spaces are converted to non breaking spaces
  65. """
  66. return fst @ pynini.cdrewrite(pynini.cross(DAMO_SPACE, DAMO_NON_BREAKING_SPACE), "", "", DAMO_SIGMA)
  67. class GraphFst:
  68. """
  69. Base class for all grammar fsts.
  70. Args:
  71. name: name of grammar class
  72. kind: either 'classify' or 'verbalize'
  73. deterministic: if True will provide a single transduction option,
  74. for False multiple transduction are generated (used for audio-based normalization)
  75. """
  76. def __init__(self, name: str, kind: str, deterministic: bool = True):
  77. self.name = name
  78. self.kind = kind
  79. self._fst = None
  80. self.deterministic = deterministic
  81. self.far_path = Path(os.path.dirname(__file__) + "/grammars/" + kind + "/" + name + ".far")
  82. if self.far_exist():
  83. self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()
  84. def far_exist(self) -> bool:
  85. """
  86. Returns true if FAR can be loaded
  87. """
  88. return self.far_path.exists()
  89. @property
  90. def fst(self) -> "pynini.FstLike":
  91. return self._fst
  92. @fst.setter
  93. def fst(self, fst):
  94. self._fst = fst
  95. def add_tokens(self, fst) -> "pynini.FstLike":
  96. """
  97. Wraps class name around to given fst
  98. Args:
  99. fst: input fst
  100. Returns:
  101. Fst: fst
  102. """
  103. return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")
  104. def delete_tokens(self, fst) -> "pynini.FstLike":
  105. """
  106. Deletes class name wrap around output of given fst
  107. Args:
  108. fst: input fst
  109. Returns:
  110. Fst: fst
  111. """
  112. res = (
  113. pynutil.delete(f"{self.name}")
  114. + delete_space
  115. + pynutil.delete("{")
  116. + delete_space
  117. + fst
  118. + delete_space
  119. + pynutil.delete("}")
  120. )
  121. return res @ pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", DAMO_SIGMA)