lm.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. import argparse
  2. import logging
  3. from typing import Callable
  4. from typing import Collection
  5. from typing import Dict
  6. from typing import List
  7. from typing import Optional
  8. from typing import Tuple
  9. import numpy as np
  10. import torch
  11. from typeguard import check_argument_types
  12. from typeguard import check_return_type
  13. from funasr.datasets.collate_fn import CommonCollateFn
  14. from funasr.datasets.preprocessor import CommonPreprocessor
  15. from funasr.lm.abs_model import AbsLM
  16. from funasr.lm.espnet_model import ESPnetLanguageModel
  17. from funasr.lm.seq_rnn_lm import SequentialRNNLM
  18. from funasr.lm.transformer_lm import TransformerLM
  19. from funasr.tasks.abs_task import AbsTask
  20. from funasr.text.phoneme_tokenizer import g2p_choices
  21. from funasr.torch_utils.initialize import initialize
  22. from funasr.train.class_choices import ClassChoices
  23. from funasr.train.trainer import Trainer
  24. from funasr.utils.get_default_kwargs import get_default_kwargs
  25. from funasr.utils.nested_dict_action import NestedDictAction
  26. from funasr.utils.types import str2bool
  27. from funasr.utils.types import str_or_none
  28. lm_choices = ClassChoices(
  29. "lm",
  30. classes=dict(
  31. seq_rnn=SequentialRNNLM,
  32. transformer=TransformerLM,
  33. ),
  34. type_check=AbsLM,
  35. default="seq_rnn",
  36. )
  37. class LMTask(AbsTask):
  38. # If you need more than one optimizers, change this value
  39. num_optimizers: int = 1
  40. # Add variable objects configurations
  41. class_choices_list = [lm_choices]
  42. # If you need to modify train() or eval() procedures, change Trainer class here
  43. trainer = Trainer
  44. @classmethod
  45. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  46. # NOTE(kamo): Use '_' instead of '-' to avoid confusion
  47. assert check_argument_types()
  48. group = parser.add_argument_group(description="Task related")
  49. # NOTE(kamo): add_arguments(..., required=True) can't be used
  50. # to provide --print_config mode. Instead of it, do as
  51. required = parser.get_default("required")
  52. # required += ["token_list"]
  53. group.add_argument(
  54. "--token_list",
  55. type=str_or_none,
  56. default=None,
  57. help="A text mapping int-id to token",
  58. )
  59. group.add_argument(
  60. "--init",
  61. type=lambda x: str_or_none(x.lower()),
  62. default=None,
  63. help="The initialization method",
  64. choices=[
  65. "chainer",
  66. "xavier_uniform",
  67. "xavier_normal",
  68. "kaiming_uniform",
  69. "kaiming_normal",
  70. None,
  71. ],
  72. )
  73. group.add_argument(
  74. "--model_conf",
  75. action=NestedDictAction,
  76. default=get_default_kwargs(ESPnetLanguageModel),
  77. help="The keyword arguments for model class.",
  78. )
  79. group = parser.add_argument_group(description="Preprocess related")
  80. group.add_argument(
  81. "--use_preprocessor",
  82. type=str2bool,
  83. default=True,
  84. help="Apply preprocessing to data or not",
  85. )
  86. group.add_argument(
  87. "--token_type",
  88. type=str,
  89. default="bpe",
  90. choices=["bpe", "char", "word"],
  91. help="",
  92. )
  93. group.add_argument(
  94. "--bpemodel",
  95. type=str_or_none,
  96. default=None,
  97. help="The model file fo sentencepiece",
  98. )
  99. parser.add_argument(
  100. "--non_linguistic_symbols",
  101. type=str_or_none,
  102. help="non_linguistic_symbols file path",
  103. )
  104. parser.add_argument(
  105. "--cleaner",
  106. type=str_or_none,
  107. choices=[None, "tacotron", "jaconv", "vietnamese"],
  108. default=None,
  109. help="Apply text cleaning",
  110. )
  111. parser.add_argument(
  112. "--g2p",
  113. type=str_or_none,
  114. choices=g2p_choices,
  115. default=None,
  116. help="Specify g2p method if --token_type=phn",
  117. )
  118. for class_choices in cls.class_choices_list:
  119. class_choices.add_arguments(group)
  120. assert check_return_type(parser)
  121. return parser
  122. @classmethod
  123. def build_collate_fn(
  124. cls, args: argparse.Namespace, train: bool
  125. ) -> Callable[
  126. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  127. Tuple[List[str], Dict[str, torch.Tensor]],
  128. ]:
  129. assert check_argument_types()
  130. return CommonCollateFn(int_pad_value=0)
  131. @classmethod
  132. def build_preprocess_fn(
  133. cls, args: argparse.Namespace, train: bool
  134. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  135. assert check_argument_types()
  136. if args.use_preprocessor:
  137. retval = CommonPreprocessor(
  138. train=train,
  139. token_type=args.token_type,
  140. token_list=args.token_list,
  141. bpemodel=args.bpemodel,
  142. text_cleaner=args.cleaner,
  143. g2p_type=args.g2p,
  144. non_linguistic_symbols=args.non_linguistic_symbols,
  145. )
  146. else:
  147. retval = None
  148. assert check_return_type(retval)
  149. return retval
  150. @classmethod
  151. def required_data_names(
  152. cls, train: bool = True, inference: bool = False
  153. ) -> Tuple[str, ...]:
  154. retval = ("text",)
  155. return retval
  156. @classmethod
  157. def optional_data_names(
  158. cls, train: bool = True, inference: bool = False
  159. ) -> Tuple[str, ...]:
  160. retval = ()
  161. return retval
  162. @classmethod
  163. def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
  164. assert check_argument_types()
  165. if isinstance(args.token_list, str):
  166. with open(args.token_list, encoding="utf-8") as f:
  167. token_list = [line.rstrip() for line in f]
  168. # "args" is saved as it is in a yaml file by BaseTask.main().
  169. # Overwriting token_list to keep it as "portable".
  170. args.token_list = token_list.copy()
  171. elif isinstance(args.token_list, (tuple, list)):
  172. token_list = args.token_list.copy()
  173. else:
  174. raise RuntimeError("token_list must be str or dict")
  175. vocab_size = len(token_list)
  176. logging.info(f"Vocabulary size: {vocab_size}")
  177. # 1. Build LM model
  178. lm_class = lm_choices.get_class(args.lm)
  179. lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
  180. # 2. Build ESPnetModel
  181. # Assume the last-id is sos_and_eos
  182. model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
  183. # 3. Initialize
  184. if args.init is not None:
  185. initialize(model, args.init)
  186. assert check_return_type(model)
  187. return model