graph_utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import os
  2. import string
  3. from pathlib import Path
  4. from typing import Dict
  5. import pynini
  6. from pynini import Far
  7. from pynini.export import export
  8. from pynini.lib import byte, pynutil, utf8
  9. DAMO_CHAR = utf8.VALID_UTF8_CHAR
  10. DAMO_DIGIT = byte.DIGIT
  11. DAMO_LOWER = pynini.union(*string.ascii_lowercase).optimize()
  12. DAMO_UPPER = pynini.union(*string.ascii_uppercase).optimize()
  13. DAMO_ALPHA = pynini.union(DAMO_LOWER, DAMO_UPPER).optimize()
  14. DAMO_ALNUM = pynini.union(DAMO_DIGIT, DAMO_ALPHA).optimize()
  15. DAMO_HEX = pynini.union(*string.hexdigits).optimize()
  16. DAMO_NON_BREAKING_SPACE = "\u00A0"
  17. DAMO_SPACE = " "
  18. DAMO_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", "\u00A0").optimize()
  19. DAMO_NOT_SPACE = pynini.difference(DAMO_CHAR, DAMO_WHITE_SPACE).optimize()
  20. DAMO_NOT_QUOTE = pynini.difference(DAMO_CHAR, r'"').optimize()
  21. DAMO_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
  22. DAMO_GRAPH = pynini.union(DAMO_ALNUM, DAMO_PUNCT).optimize()
  23. DAMO_SIGMA = pynini.closure(DAMO_CHAR)
  24. delete_space = pynutil.delete(pynini.closure(DAMO_WHITE_SPACE))
  25. insert_space = pynutil.insert(" ")
  26. delete_extra_space = pynini.cross(pynini.closure(DAMO_WHITE_SPACE, 1), " ")
  27. # French frequently compounds numbers with hyphen.
  28. delete_hyphen = pynutil.delete(pynini.closure("-", 0, 1))
  29. insert_hyphen = pynutil.insert("-")
  30. TO_LOWER = pynini.union(*[pynini.cross(x, y) for x, y in zip(string.ascii_uppercase, string.ascii_lowercase)])
  31. TO_UPPER = pynini.invert(TO_LOWER)
  32. def generator_main(file_name: str, graphs: Dict[str, pynini.FstLike]):
  33. """
  34. Exports graph as OpenFst finite state archive (FAR) file with given file name and rule name.
  35. Args:
  36. file_name: exported file name
  37. graphs: Mapping of a rule name and Pynini WFST graph to be exported
  38. """
  39. exporter = export.Exporter(file_name)
  40. for rule, graph in graphs.items():
  41. exporter[rule] = graph.optimize()
  42. exporter.close()
  43. print(f"Created {file_name}")
  44. def convert_space(fst) -> "pynini.FstLike":
  45. """
  46. Converts space to nonbreaking space.
  47. Used only in tagger grammars for transducing token values within quotes, e.g. name: "hello kitty"
  48. This is making transducer significantly slower, so only use when there could be potential spaces within quotes, otherwise leave it.
  49. Args:
  50. fst: input fst
  51. Returns output fst where breaking spaces are converted to non breaking spaces
  52. """
  53. return fst @ pynini.cdrewrite(pynini.cross(DAMO_SPACE, DAMO_NON_BREAKING_SPACE), "", "", DAMO_SIGMA)
  54. class GraphFst:
  55. """
  56. Base class for all grammar fsts.
  57. Args:
  58. name: name of grammar class
  59. kind: either 'classify' or 'verbalize'
  60. deterministic: if True will provide a single transduction option,
  61. for False multiple transduction are generated (used for audio-based normalization)
  62. """
  63. def __init__(self, name: str, kind: str, deterministic: bool = True):
  64. self.name = name
  65. self.kind = kind
  66. self._fst = None
  67. self.deterministic = deterministic
  68. self.far_path = Path(os.path.dirname(__file__) + "/grammars/" + kind + "/" + name + ".far")
  69. if self.far_exist():
  70. self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()
  71. def far_exist(self) -> bool:
  72. """
  73. Returns true if FAR can be loaded
  74. """
  75. return self.far_path.exists()
  76. @property
  77. def fst(self) -> "pynini.FstLike":
  78. return self._fst
  79. @fst.setter
  80. def fst(self, fst):
  81. self._fst = fst
  82. def add_tokens(self, fst) -> "pynini.FstLike":
  83. """
  84. Wraps class name around to given fst
  85. Args:
  86. fst: input fst
  87. Returns:
  88. Fst: fst
  89. """
  90. return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")
  91. def delete_tokens(self, fst) -> "pynini.FstLike":
  92. """
  93. Deletes class name wrap around output of given fst
  94. Args:
  95. fst: input fst
  96. Returns:
  97. Fst: fst
  98. """
  99. res = (
  100. pynutil.delete(f"{self.name}")
  101. + delete_space
  102. + pynutil.delete("{")
  103. + delete_space
  104. + fst
  105. + delete_space
  106. + pynutil.delete("}")
  107. )
  108. return res @ pynini.cdrewrite(pynini.cross("\u00A0", " "), "", "", DAMO_SIGMA)