# punctuation.py
  1. import argparse
  2. import logging
  3. from typing import Callable
  4. from typing import Collection
  5. from typing import Dict
  6. from typing import List
  7. from typing import Optional
  8. from typing import Tuple
  9. import numpy as np
  10. import torch
  11. from typeguard import check_argument_types
  12. from typeguard import check_return_type
  13. from funasr.datasets.collate_fn import CommonCollateFn
  14. from funasr.datasets.preprocessor import MutliTokenizerCommonPreprocessor
  15. from funasr.punctuation.abs_model import AbsPunctuation
  16. from funasr.punctuation.espnet_model import ESPnetPunctuationModel
  17. from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
  18. from funasr.tasks.abs_task import AbsTask
  19. from funasr.text.phoneme_tokenizer import g2p_choices
  20. from funasr.torch_utils.initialize import initialize
  21. from funasr.train.class_choices import ClassChoices
  22. from funasr.train.trainer import Trainer
  23. from funasr.utils.get_default_kwargs import get_default_kwargs
  24. from funasr.utils.nested_dict_action import NestedDictAction
  25. from funasr.utils.types import str2bool
  26. from funasr.utils.types import str_or_none
  27. punc_choices = ClassChoices(
  28. "punctuation",
  29. classes=dict(
  30. target_delay=TargetDelayTransformer,
  31. ),
  32. type_check=AbsPunctuation,
  33. default="TargetDelayTransformer",
  34. )
  35. class PunctuationTask(AbsTask):
  36. # If you need more than one optimizers, change this value
  37. num_optimizers: int = 1
  38. # Add variable objects configurations
  39. class_choices_list = [punc_choices]
  40. # If you need to modify train() or eval() procedures, change Trainer class here
  41. trainer = Trainer
  42. @classmethod
  43. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  44. # NOTE(kamo): Use '_' instead of '-' to avoid confusion
  45. assert check_argument_types()
  46. group = parser.add_argument_group(description="Task related")
  47. # NOTE(kamo): add_arguments(..., required=True) can't be used
  48. # to provide --print_config mode. Instead of it, do as
  49. required = parser.get_default("required")
  50. #import pdb;pdb.set_trace()
  51. #required += ["token_list"]
  52. group.add_argument(
  53. "--token_list",
  54. type=str_or_none,
  55. default=None,
  56. help="A text mapping int-id to token",
  57. )
  58. group.add_argument(
  59. "--init",
  60. type=lambda x: str_or_none(x.lower()),
  61. default=None,
  62. help="The initialization method",
  63. choices=[
  64. "chainer",
  65. "xavier_uniform",
  66. "xavier_normal",
  67. "kaiming_uniform",
  68. "kaiming_normal",
  69. None,
  70. ],
  71. )
  72. group.add_argument(
  73. "--model_conf",
  74. action=NestedDictAction,
  75. default=get_default_kwargs(ESPnetPunctuationModel),
  76. help="The keyword arguments for model class.",
  77. )
  78. group = parser.add_argument_group(description="Preprocess related")
  79. group.add_argument(
  80. "--use_preprocessor",
  81. type=str2bool,
  82. default=True,
  83. help="Apply preprocessing to data or not",
  84. )
  85. group.add_argument(
  86. "--token_type",
  87. type=str,
  88. default="bpe",
  89. choices=["bpe", "char", "word"],
  90. help="",
  91. )
  92. group.add_argument(
  93. "--bpemodel",
  94. type=str_or_none,
  95. default=None,
  96. help="The model file fo sentencepiece",
  97. )
  98. parser.add_argument(
  99. "--non_linguistic_symbols",
  100. type=str_or_none,
  101. help="non_linguistic_symbols file path",
  102. )
  103. parser.add_argument(
  104. "--cleaner",
  105. type=str_or_none,
  106. choices=[None, "tacotron", "jaconv", "vietnamese"],
  107. default=None,
  108. help="Apply text cleaning",
  109. )
  110. parser.add_argument(
  111. "--g2p",
  112. type=str_or_none,
  113. choices=g2p_choices,
  114. default=None,
  115. help="Specify g2p method if --token_type=phn",
  116. )
  117. for class_choices in cls.class_choices_list:
  118. # Append --<name> and --<name>_conf.
  119. # e.g. --encoder and --encoder_conf
  120. class_choices.add_arguments(group)
  121. assert check_return_type(parser)
  122. return parser
  123. @classmethod
  124. def build_collate_fn(
  125. cls, args: argparse.Namespace, train: bool
  126. ) -> Callable[
  127. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  128. Tuple[List[str], Dict[str, torch.Tensor]],
  129. ]:
  130. assert check_argument_types()
  131. return CommonCollateFn(int_pad_value=0)
  132. @classmethod
  133. def build_preprocess_fn(
  134. cls, args: argparse.Namespace, train: bool
  135. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  136. assert check_argument_types()
  137. token_types = [args.token_type, args.token_type]
  138. token_lists = [args.token_list, args.punc_list]
  139. bpemodels = [args.bpemodel, args.bpemodel]
  140. text_names = ["text", "punc"]
  141. if args.use_preprocessor:
  142. retval = MutliTokenizerCommonPreprocessor(
  143. train=train,
  144. token_type=token_types,
  145. token_list=token_lists,
  146. bpemodel=bpemodels,
  147. text_cleaner=args.cleaner,
  148. g2p_type=args.g2p,
  149. text_name = text_names,
  150. non_linguistic_symbols=args.non_linguistic_symbols,
  151. )
  152. else:
  153. retval = None
  154. assert check_return_type(retval)
  155. return retval
  156. @classmethod
  157. def required_data_names(
  158. cls, train: bool = True, inference: bool = False
  159. ) -> Tuple[str, ...]:
  160. retval = ("text", "punc")
  161. if inference:
  162. retval = ("text", )
  163. return retval
  164. @classmethod
  165. def optional_data_names(
  166. cls, train: bool = True, inference: bool = False
  167. ) -> Tuple[str, ...]:
  168. retval = ()
  169. return retval
  170. @classmethod
  171. def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
  172. assert check_argument_types()
  173. if isinstance(args.token_list, str):
  174. with open(args.token_list, encoding="utf-8") as f:
  175. token_list = [line.rstrip() for line in f]
  176. # "args" is saved as it is in a yaml file by BaseTask.main().
  177. # Overwriting token_list to keep it as "portable".
  178. args.token_list = token_list.copy()
  179. if isinstance(args.punc_list, str):
  180. with open(args.punc_list, encoding="utf-8") as f2:
  181. punc_list = [line.rstrip() for line in f2]
  182. args.punc_list = punc_list.copy()
  183. elif isinstance(args.punc_list, list):
  184. # This is in the inference code path.
  185. punc_list = args.punc_list.copy()
  186. if isinstance(args.token_list, (tuple, list)):
  187. token_list = args.token_list.copy()
  188. else:
  189. raise RuntimeError("token_list must be str or dict")
  190. vocab_size = len(token_list)
  191. punc_size = len(punc_list)
  192. logging.info(f"Vocabulary size: {vocab_size}")
  193. # 1. Build PUNC model
  194. punc_class = punc_choices.get_class(args.punctuation)
  195. punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
  196. # 2. Build ESPnetModel
  197. # Assume the last-id is sos_and_eos
  198. model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, **args.model_conf)
  199. # FIXME(kamo): Should be done in model?
  200. # 3. Initialize
  201. if args.init is not None:
  202. initialize(model, args.init)
  203. assert check_return_type(model)
  204. return model