# verbalize_final.py
  1. import os
  2. import pynini
  3. from fun_text_processing.text_normalization.zh.graph_utils import GraphFst, delete_space, generator_main
  4. from fun_text_processing.text_normalization.zh.verbalizers.postprocessor import PostProcessor
  5. from fun_text_processing.text_normalization.zh.verbalizers.verbalize import VerbalizeFst
  6. from pynini.lib import pynutil
  7. # import logging
  8. class VerbalizeFinalFst(GraphFst):
  9. """
  10. """
  11. def __init__(self, deterministic: bool = True, cache_dir: str = None, overwrite_cache: bool = False):
  12. super().__init__(name="verbalize_final", kind="verbalize", deterministic=deterministic)
  13. far_file = None
  14. if cache_dir is not None and cache_dir != "None":
  15. os.makedirs(cache_dir, exist_ok=True)
  16. far_file = os.path.join(cache_dir, f"zh_tn_{deterministic}_deterministic_verbalizer.far")
  17. if not overwrite_cache and far_file and os.path.exists(far_file):
  18. self.fst = pynini.Far(far_file, mode="r")["verbalize"]
  19. else:
  20. token_graph = VerbalizeFst(deterministic=deterministic)
  21. token_verbalizer = (
  22. pynutil.delete("tokens {") + delete_space + token_graph.fst + delete_space + pynutil.delete(" }")
  23. )
  24. verbalizer = pynini.closure(delete_space + token_verbalizer + delete_space)
  25. postprocessor = PostProcessor(remove_puncts=False, to_upper=False, to_lower=False, tag_oov=False,)
  26. self.fst = (verbalizer @ postprocessor.fst).optimize()
  27. if far_file:
  28. generator_main(far_file, {"verbalize": self.fst})