# punctuation.py
import argparse
import logging
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import torch

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.tasks.abs_task import AbsTask
from funasr.tokenizer.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_model import PunctuationModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
  25. punc_choices = ClassChoices(
  26. "punctuation",
  27. classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
  28. default="target_delay",
  29. )
  30. class PunctuationTask(AbsTask):
  31. # If you need more than one optimizers, change this value
  32. num_optimizers: int = 1
  33. # Add variable objects configurations
  34. class_choices_list = [punc_choices]
  35. # If you need to modify train() or eval() procedures, change Trainer class here
  36. trainer = Trainer
  37. @classmethod
  38. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  39. # NOTE(kamo): Use '_' instead of '-' to avoid confusion
  40. group = parser.add_argument_group(description="Task related")
  41. # NOTE(kamo): add_arguments(..., required=True) can't be used
  42. # to provide --print_config mode. Instead of it, do as
  43. required = parser.get_default("required")
  44. group.add_argument(
  45. "--token_list",
  46. type=str_or_none,
  47. default=None,
  48. help="A text mapping int-id to token",
  49. )
  50. group.add_argument(
  51. "--init",
  52. type=lambda x: str_or_none(x.lower()),
  53. default=None,
  54. help="The initialization method",
  55. choices=[
  56. "chainer",
  57. "xavier_uniform",
  58. "xavier_normal",
  59. "kaiming_uniform",
  60. "kaiming_normal",
  61. None,
  62. ],
  63. )
  64. group.add_argument(
  65. "--model_conf",
  66. action=NestedDictAction,
  67. default=get_default_kwargs(PunctuationModel),
  68. help="The keyword arguments for model class.",
  69. )
  70. group = parser.add_argument_group(description="Preprocess related")
  71. group.add_argument(
  72. "--use_preprocessor",
  73. type=str2bool,
  74. default=True,
  75. help="Apply preprocessing to data or not",
  76. )
  77. group.add_argument(
  78. "--token_type",
  79. type=str,
  80. default="bpe",
  81. choices=["bpe", "char", "word"],
  82. help="",
  83. )
  84. group.add_argument(
  85. "--bpemodel",
  86. type=str_or_none,
  87. default=None,
  88. help="The model file fo sentencepiece",
  89. )
  90. parser.add_argument(
  91. "--non_linguistic_symbols",
  92. type=str_or_none,
  93. help="non_linguistic_symbols file path",
  94. )
  95. parser.add_argument(
  96. "--cleaner",
  97. type=str_or_none,
  98. choices=[None, "tacotron", "jaconv", "vietnamese"],
  99. default=None,
  100. help="Apply text cleaning",
  101. )
  102. parser.add_argument(
  103. "--g2p",
  104. type=str_or_none,
  105. choices=g2p_choices,
  106. default=None,
  107. help="Specify g2p method if --token_type=phn",
  108. )
  109. for class_choices in cls.class_choices_list:
  110. # Append --<name> and --<name>_conf.
  111. # e.g. --encoder and --encoder_conf
  112. class_choices.add_arguments(group)
  113. return parser
  114. @classmethod
  115. def build_collate_fn(
  116. cls, args: argparse.Namespace, train: bool
  117. ) -> Callable[
  118. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  119. Tuple[List[str], Dict[str, torch.Tensor]],
  120. ]:
  121. return CommonCollateFn(int_pad_value=0)
  122. @classmethod
  123. def build_preprocess_fn(
  124. cls, args: argparse.Namespace, train: bool
  125. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  126. token_types = [args.token_type, args.token_type]
  127. token_lists = [args.token_list, args.punc_list]
  128. bpemodels = [args.bpemodel, args.bpemodel]
  129. text_names = ["text", "punc"]
  130. if args.use_preprocessor:
  131. retval = PuncTrainTokenizerCommonPreprocessor(
  132. train=train,
  133. token_type=token_types,
  134. token_list=token_lists,
  135. bpemodel=bpemodels,
  136. text_cleaner=args.cleaner,
  137. g2p_type=args.g2p,
  138. text_name = text_names,
  139. non_linguistic_symbols=args.non_linguistic_symbols,
  140. )
  141. else:
  142. retval = None
  143. return retval
  144. @classmethod
  145. def required_data_names(
  146. cls, train: bool = True, inference: bool = False
  147. ) -> Tuple[str, ...]:
  148. retval = ("text", "punc")
  149. if inference:
  150. retval = ("text", )
  151. return retval
  152. @classmethod
  153. def optional_data_names(
  154. cls, train: bool = True, inference: bool = False
  155. ) -> Tuple[str, ...]:
  156. retval = ("vad",)
  157. return retval
  158. @classmethod
  159. def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
  160. if isinstance(args.token_list, str):
  161. with open(args.token_list, encoding="utf-8") as f:
  162. token_list = [line.rstrip() for line in f]
  163. # "args" is saved as it is in a yaml file by BaseTask.main().
  164. # Overwriting token_list to keep it as "portable".
  165. args.token_list = token_list.copy()
  166. if isinstance(args.punc_list, str):
  167. with open(args.punc_list, encoding="utf-8") as f2:
  168. pairs = [line.rstrip().split(":") for line in f2]
  169. punc_list = [pair[0] for pair in pairs]
  170. punc_weight_list = [float(pair[1]) for pair in pairs]
  171. args.punc_list = punc_list.copy()
  172. elif isinstance(args.punc_list, list):
  173. punc_list = args.punc_list.copy()
  174. punc_weight_list = [1] * len(punc_list)
  175. if isinstance(args.token_list, (tuple, list)):
  176. token_list = args.token_list.copy()
  177. else:
  178. raise RuntimeError("token_list must be str or dict")
  179. vocab_size = len(token_list)
  180. punc_size = len(punc_list)
  181. logging.info(f"Vocabulary size: {vocab_size}")
  182. # 1. Build PUNC model
  183. punc_class = punc_choices.get_class(args.punctuation)
  184. punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
  185. # 2. Build ESPnetModel
  186. # Assume the last-id is sos_and_eos
  187. if "punc_weight" in args.model_conf:
  188. args.model_conf.pop("punc_weight")
  189. model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
  190. # FIXME(kamo): Should be done in model?
  191. # 3. Initialize
  192. if args.init is not None:
  193. initialize(model, args.init)
  194. return model