# lm.py — FunASR language-model task definition.
  1. import argparse
  2. import logging
  3. from typing import Callable
  4. from typing import Collection
  5. from typing import Dict
  6. from typing import List
  7. from typing import Optional
  8. from typing import Tuple
  9. import numpy as np
  10. import torch
  11. from funasr.datasets.collate_fn import CommonCollateFn
  12. from funasr.datasets.preprocessor import CommonPreprocessor
  13. from funasr.train.abs_model import AbsLM
  14. from funasr.train.abs_model import LanguageModel
  15. from funasr.models.seq_rnn_lm import SequentialRNNLM
  16. from funasr.models.transformer_lm import TransformerLM
  17. from funasr.tasks.abs_task import AbsTask
  18. from funasr.text.phoneme_tokenizer import g2p_choices
  19. from funasr.torch_utils.initialize import initialize
  20. from funasr.train.class_choices import ClassChoices
  21. from funasr.train.trainer import Trainer
  22. from funasr.utils.get_default_kwargs import get_default_kwargs
  23. from funasr.utils.nested_dict_action import NestedDictAction
  24. from funasr.utils.types import str2bool
  25. from funasr.utils.types import str_or_none
  26. lm_choices = ClassChoices(
  27. "lm",
  28. classes=dict(
  29. seq_rnn=SequentialRNNLM,
  30. transformer=TransformerLM,
  31. ),
  32. type_check=AbsLM,
  33. default="seq_rnn",
  34. )
  35. class LMTask(AbsTask):
  36. # If you need more than one optimizers, change this value
  37. num_optimizers: int = 1
  38. # Add variable objects configurations
  39. class_choices_list = [lm_choices]
  40. # If you need to modify train() or eval() procedures, change Trainer class here
  41. trainer = Trainer
  42. @classmethod
  43. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  44. # NOTE(kamo): Use '_' instead of '-' to avoid confusion
  45. group = parser.add_argument_group(description="Task related")
  46. # NOTE(kamo): add_arguments(..., required=True) can't be used
  47. # to provide --print_config mode. Instead of it, do as
  48. required = parser.get_default("required")
  49. # required += ["token_list"]
  50. group.add_argument(
  51. "--token_list",
  52. type=str_or_none,
  53. default=None,
  54. help="A text mapping int-id to token",
  55. )
  56. group.add_argument(
  57. "--init",
  58. type=lambda x: str_or_none(x.lower()),
  59. default=None,
  60. help="The initialization method",
  61. choices=[
  62. "chainer",
  63. "xavier_uniform",
  64. "xavier_normal",
  65. "kaiming_uniform",
  66. "kaiming_normal",
  67. None,
  68. ],
  69. )
  70. group.add_argument(
  71. "--model_conf",
  72. action=NestedDictAction,
  73. default=get_default_kwargs(LanguageModel),
  74. help="The keyword arguments for model class.",
  75. )
  76. group = parser.add_argument_group(description="Preprocess related")
  77. group.add_argument(
  78. "--use_preprocessor",
  79. type=str2bool,
  80. default=True,
  81. help="Apply preprocessing to data or not",
  82. )
  83. group.add_argument(
  84. "--token_type",
  85. type=str,
  86. default="bpe",
  87. choices=["bpe", "char", "word"],
  88. help="",
  89. )
  90. group.add_argument(
  91. "--bpemodel",
  92. type=str_or_none,
  93. default=None,
  94. help="The model file fo sentencepiece",
  95. )
  96. parser.add_argument(
  97. "--non_linguistic_symbols",
  98. type=str_or_none,
  99. help="non_linguistic_symbols file path",
  100. )
  101. parser.add_argument(
  102. "--cleaner",
  103. type=str_or_none,
  104. choices=[None, "tacotron", "jaconv", "vietnamese"],
  105. default=None,
  106. help="Apply text cleaning",
  107. )
  108. parser.add_argument(
  109. "--g2p",
  110. type=str_or_none,
  111. choices=g2p_choices,
  112. default=None,
  113. help="Specify g2p method if --token_type=phn",
  114. )
  115. for class_choices in cls.class_choices_list:
  116. class_choices.add_arguments(group)
  117. return parser
  118. @classmethod
  119. def build_collate_fn(
  120. cls, args: argparse.Namespace, train: bool
  121. ) -> Callable[
  122. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  123. Tuple[List[str], Dict[str, torch.Tensor]],
  124. ]:
  125. return CommonCollateFn(int_pad_value=0)
  126. @classmethod
  127. def build_preprocess_fn(
  128. cls, args: argparse.Namespace, train: bool
  129. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  130. if args.use_preprocessor:
  131. retval = CommonPreprocessor(
  132. train=train,
  133. token_type=args.token_type,
  134. token_list=args.token_list,
  135. bpemodel=args.bpemodel,
  136. text_cleaner=args.cleaner,
  137. g2p_type=args.g2p,
  138. non_linguistic_symbols=args.non_linguistic_symbols,
  139. )
  140. else:
  141. retval = None
  142. return retval
  143. @classmethod
  144. def required_data_names(
  145. cls, train: bool = True, inference: bool = False
  146. ) -> Tuple[str, ...]:
  147. retval = ("text",)
  148. return retval
  149. @classmethod
  150. def optional_data_names(
  151. cls, train: bool = True, inference: bool = False
  152. ) -> Tuple[str, ...]:
  153. retval = ()
  154. return retval
  155. @classmethod
  156. def build_model(cls, args: argparse.Namespace) -> LanguageModel:
  157. if isinstance(args.token_list, str):
  158. with open(args.token_list, encoding="utf-8") as f:
  159. token_list = [line.rstrip() for line in f]
  160. # "args" is saved as it is in a yaml file by BaseTask.main().
  161. # Overwriting token_list to keep it as "portable".
  162. args.token_list = token_list.copy()
  163. elif isinstance(args.token_list, (tuple, list)):
  164. token_list = args.token_list.copy()
  165. else:
  166. raise RuntimeError("token_list must be str or dict")
  167. vocab_size = len(token_list)
  168. logging.info(f"Vocabulary size: {vocab_size}")
  169. # 1. Build LM model
  170. lm_class = lm_choices.get_class(args.lm)
  171. lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
  172. # 2. Build ESPnetModel
  173. # Assume the last-id is sos_and_eos
  174. model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
  175. # 3. Initialize
  176. if args.init is not None:
  177. initialize(model, args.init)
  178. return model