# punctuation.py
  1. import argparse
  2. import logging
  3. from typing import Callable
  4. from typing import Collection
  5. from typing import Dict
  6. from typing import List
  7. from typing import Optional
  8. from typing import Tuple
  9. import numpy as np
  10. import torch
  11. from typeguard import check_argument_types
  12. from typeguard import check_return_type
  13. from funasr.datasets.collate_fn import CommonCollateFn
  14. from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
  15. from funasr.train.abs_model import AbsPunctuation
  16. from funasr.train.abs_model import PunctuationModel
  17. from funasr.models.target_delay_transformer import TargetDelayTransformer
  18. from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
  19. from funasr.tasks.abs_task import AbsTask
  20. from funasr.text.phoneme_tokenizer import g2p_choices
  21. from funasr.torch_utils.initialize import initialize
  22. from funasr.train.class_choices import ClassChoices
  23. from funasr.train.trainer import Trainer
  24. from funasr.utils.get_default_kwargs import get_default_kwargs
  25. from funasr.utils.nested_dict_action import NestedDictAction
  26. from funasr.utils.types import str2bool
  27. from funasr.utils.types import str_or_none
  28. punc_choices = ClassChoices(
  29. "punctuation",
  30. classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
  31. type_check=AbsPunctuation,
  32. default="target_delay",
  33. )
  34. class PunctuationTask(AbsTask):
  35. # If you need more than one optimizers, change this value
  36. num_optimizers: int = 1
  37. # Add variable objects configurations
  38. class_choices_list = [punc_choices]
  39. # If you need to modify train() or eval() procedures, change Trainer class here
  40. trainer = Trainer
  41. @classmethod
  42. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  43. # NOTE(kamo): Use '_' instead of '-' to avoid confusion
  44. assert check_argument_types()
  45. group = parser.add_argument_group(description="Task related")
  46. # NOTE(kamo): add_arguments(..., required=True) can't be used
  47. # to provide --print_config mode. Instead of it, do as
  48. required = parser.get_default("required")
  49. group.add_argument(
  50. "--token_list",
  51. type=str_or_none,
  52. default=None,
  53. help="A text mapping int-id to token",
  54. )
  55. group.add_argument(
  56. "--init",
  57. type=lambda x: str_or_none(x.lower()),
  58. default=None,
  59. help="The initialization method",
  60. choices=[
  61. "chainer",
  62. "xavier_uniform",
  63. "xavier_normal",
  64. "kaiming_uniform",
  65. "kaiming_normal",
  66. None,
  67. ],
  68. )
  69. group.add_argument(
  70. "--model_conf",
  71. action=NestedDictAction,
  72. default=get_default_kwargs(PunctuationModel),
  73. help="The keyword arguments for model class.",
  74. )
  75. group = parser.add_argument_group(description="Preprocess related")
  76. group.add_argument(
  77. "--use_preprocessor",
  78. type=str2bool,
  79. default=True,
  80. help="Apply preprocessing to data or not",
  81. )
  82. group.add_argument(
  83. "--token_type",
  84. type=str,
  85. default="bpe",
  86. choices=["bpe", "char", "word"],
  87. help="",
  88. )
  89. group.add_argument(
  90. "--bpemodel",
  91. type=str_or_none,
  92. default=None,
  93. help="The model file fo sentencepiece",
  94. )
  95. parser.add_argument(
  96. "--non_linguistic_symbols",
  97. type=str_or_none,
  98. help="non_linguistic_symbols file path",
  99. )
  100. parser.add_argument(
  101. "--cleaner",
  102. type=str_or_none,
  103. choices=[None, "tacotron", "jaconv", "vietnamese"],
  104. default=None,
  105. help="Apply text cleaning",
  106. )
  107. parser.add_argument(
  108. "--g2p",
  109. type=str_or_none,
  110. choices=g2p_choices,
  111. default=None,
  112. help="Specify g2p method if --token_type=phn",
  113. )
  114. for class_choices in cls.class_choices_list:
  115. # Append --<name> and --<name>_conf.
  116. # e.g. --encoder and --encoder_conf
  117. class_choices.add_arguments(group)
  118. assert check_return_type(parser)
  119. return parser
  120. @classmethod
  121. def build_collate_fn(
  122. cls, args: argparse.Namespace, train: bool
  123. ) -> Callable[
  124. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  125. Tuple[List[str], Dict[str, torch.Tensor]],
  126. ]:
  127. assert check_argument_types()
  128. return CommonCollateFn(int_pad_value=0)
  129. @classmethod
  130. def build_preprocess_fn(
  131. cls, args: argparse.Namespace, train: bool
  132. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  133. assert check_argument_types()
  134. token_types = [args.token_type, args.token_type]
  135. token_lists = [args.token_list, args.punc_list]
  136. bpemodels = [args.bpemodel, args.bpemodel]
  137. text_names = ["text", "punc"]
  138. if args.use_preprocessor:
  139. retval = PuncTrainTokenizerCommonPreprocessor(
  140. train=train,
  141. token_type=token_types,
  142. token_list=token_lists,
  143. bpemodel=bpemodels,
  144. text_cleaner=args.cleaner,
  145. g2p_type=args.g2p,
  146. text_name = text_names,
  147. non_linguistic_symbols=args.non_linguistic_symbols,
  148. )
  149. else:
  150. retval = None
  151. assert check_return_type(retval)
  152. return retval
  153. @classmethod
  154. def required_data_names(
  155. cls, train: bool = True, inference: bool = False
  156. ) -> Tuple[str, ...]:
  157. retval = ("text", "punc")
  158. if inference:
  159. retval = ("text", )
  160. return retval
  161. @classmethod
  162. def optional_data_names(
  163. cls, train: bool = True, inference: bool = False
  164. ) -> Tuple[str, ...]:
  165. retval = ("vad",)
  166. return retval
  167. @classmethod
  168. def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
  169. assert check_argument_types()
  170. if isinstance(args.token_list, str):
  171. with open(args.token_list, encoding="utf-8") as f:
  172. token_list = [line.rstrip() for line in f]
  173. # "args" is saved as it is in a yaml file by BaseTask.main().
  174. # Overwriting token_list to keep it as "portable".
  175. args.token_list = token_list.copy()
  176. if isinstance(args.punc_list, str):
  177. with open(args.punc_list, encoding="utf-8") as f2:
  178. pairs = [line.rstrip().split(":") for line in f2]
  179. punc_list = [pair[0] for pair in pairs]
  180. punc_weight_list = [float(pair[1]) for pair in pairs]
  181. args.punc_list = punc_list.copy()
  182. elif isinstance(args.punc_list, list):
  183. punc_list = args.punc_list.copy()
  184. punc_weight_list = [1] * len(punc_list)
  185. if isinstance(args.token_list, (tuple, list)):
  186. token_list = args.token_list.copy()
  187. else:
  188. raise RuntimeError("token_list must be str or dict")
  189. vocab_size = len(token_list)
  190. punc_size = len(punc_list)
  191. logging.info(f"Vocabulary size: {vocab_size}")
  192. # 1. Build PUNC model
  193. punc_class = punc_choices.get_class(args.punctuation)
  194. punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
  195. # 2. Build ESPnetModel
  196. # Assume the last-id is sos_and_eos
  197. if "punc_weight" in args.model_conf:
  198. args.model_conf.pop("punc_weight")
  199. model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
  200. # FIXME(kamo): Should be done in model?
  201. # 3. Initialize
  202. if args.init is not None:
  203. initialize(model, args.init)
  204. assert check_return_type(model)
  205. return model