graph_utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import os
  2. import string
  3. from pathlib import Path
  4. from typing import Dict
  5. import pynini
  6. from pynini import Far
  7. from pynini.export import export
  8. from pynini.lib import byte, pynutil, utf8
  9. FUN_CHAR = utf8.VALID_UTF8_CHAR
  10. FUN_DIGIT = byte.DIGIT
  11. FUN_LOWER = pynini.union(*string.ascii_lowercase).optimize()
  12. FUN_UPPER = pynini.union(*string.ascii_uppercase).optimize()
  13. FUN_ALPHA = pynini.union(FUN_LOWER, FUN_UPPER).optimize()
  14. FUN_SPACE = " "
  15. FUN_WHITE_SPACE = pynini.union(" ", "\t", "\n", "\r", u"\u00A0").optimize()
  16. FUN_NOT_SPACE = pynini.difference(FUN_CHAR, FUN_WHITE_SPACE).optimize()
  17. FUN_NOT_QUOTE = pynini.difference(FUN_CHAR, r'"').optimize()
  18. FUN_PUNCT = pynini.union(*map(pynini.escape, string.punctuation)).optimize()
  19. FUN_SIGMA = pynini.closure(FUN_CHAR)
  20. delete_space = pynutil.delete(pynini.closure(FUN_WHITE_SPACE))
  21. delete_zero_or_one_space = pynutil.delete(pynini.closure(FUN_WHITE_SPACE, 0, 1))
  22. insert_space = pynutil.insert(" ")
  23. delete_extra_space = pynini.cross(pynini.closure(FUN_WHITE_SPACE, 1), " ")
  24. def generator_main(file_name: str, graphs: Dict[str, 'pynini.FstLike']):
  25. """
  26. Exports graph as OpenFst finite state archive (FAR) file with given file name and rule name.
  27. Args:
  28. file_name: exported file name
  29. graphs: Mapping of a rule name and Pynini WFST graph to be exported
  30. """
  31. exporter = export.Exporter(file_name)
  32. for rule, graph in graphs.items():
  33. exporter[rule] = graph.optimize()
  34. exporter.close()
  35. print(f'Created {file_name}')
  36. class GraphFst:
  37. """
  38. Base class for all grammar fsts.
  39. Args:
  40. name: name of grammar class
  41. kind: either 'classify' or 'verbalize'
  42. deterministic: if True will provide a single transduction option,
  43. for False multiple transduction are generated (used for audio-based normalization)
  44. """
  45. def __init__(self, name: str, kind: str, deterministic: bool = True):
  46. self.name = name
  47. self.kind = kind
  48. self._fst = None
  49. self.deterministic = deterministic
  50. self.far_path = Path(os.path.dirname(__file__) + '/grammars/' + kind + '/' + name + '.far')
  51. if self.far_exist():
  52. self._fst = Far(self.far_path, mode="r", arc_type="standard", far_type="default").get_fst()
  53. def far_exist(self) -> bool:
  54. """
  55. Returns true if FAR can be loaded
  56. """
  57. return self.far_path.exists()
  58. @property
  59. def fst(self) -> 'pynini.FstLike':
  60. return self._fst
  61. @fst.setter
  62. def fst(self, fst):
  63. self._fst = fst
  64. def add_tokens(self, fst) -> 'pynini.FstLike':
  65. """
  66. Wraps class name around to given fst
  67. Args:
  68. fst: input fst
  69. Returns:
  70. Fst: fst
  71. """
  72. return pynutil.insert(f"{self.name} {{ ") + fst + pynutil.insert(" }")
  73. def delete_tokens(self, fst) -> 'pynini.FstLike':
  74. """
  75. Deletes class name wrap around output of given fst
  76. Args:
  77. fst: input fst
  78. Returns:
  79. Fst: fst
  80. """
  81. res = (
  82. pynutil.delete(f"{self.name}")
  83. + delete_space
  84. + pynutil.delete("{")
  85. + delete_space
  86. + fst
  87. + delete_space
  88. + pynutil.delete("}")
  89. )
  90. return res @ pynini.cdrewrite(pynini.cross(u"\u00A0", " "), "", "", FUN_SIGMA)