build_sv_model.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import logging
  2. import torch
  3. from typeguard import check_return_type
  4. from funasr.layers.abs_normalize import AbsNormalize
  5. from funasr.layers.global_mvn import GlobalMVN
  6. from funasr.layers.utterance_mvn import UtteranceMVN
  7. from funasr.models.base_model import FunASRModel
  8. from funasr.models.decoder.abs_decoder import AbsDecoder
  9. from funasr.models.decoder.sv_decoder import DenseDecoder
  10. from funasr.models.e2e_sv import ESPnetSVModel
  11. from funasr.models.encoder.abs_encoder import AbsEncoder
  12. from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
  13. from funasr.models.encoder.rnn_encoder import RNNEncoder
  14. from funasr.models.frontend.abs_frontend import AbsFrontend
  15. from funasr.models.frontend.default import DefaultFrontend
  16. from funasr.models.frontend.fused import FusedFrontends
  17. from funasr.models.frontend.s3prl import S3prlFrontend
  18. from funasr.models.frontend.wav_frontend import WavFrontend
  19. from funasr.models.frontend.windowing import SlidingWindow
  20. from funasr.models.pooling.statistic_pooling import StatisticPooling
  21. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  22. from funasr.models.postencoder.hugging_face_transformers_postencoder import (
  23. HuggingFaceTransformersPostEncoder, # noqa: H301
  24. )
  25. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  26. from funasr.models.preencoder.linear import LinearProjection
  27. from funasr.models.preencoder.sinc import LightweightSincConvs
  28. from funasr.models.specaug.abs_specaug import AbsSpecAug
  29. from funasr.models.specaug.specaug import SpecAug
  30. from funasr.torch_utils.initialize import initialize
  31. from funasr.train.class_choices import ClassChoices
  32. frontend_choices = ClassChoices(
  33. name="frontend",
  34. classes=dict(
  35. default=DefaultFrontend,
  36. sliding_window=SlidingWindow,
  37. s3prl=S3prlFrontend,
  38. fused=FusedFrontends,
  39. wav_frontend=WavFrontend,
  40. ),
  41. type_check=AbsFrontend,
  42. default="default",
  43. )
  44. specaug_choices = ClassChoices(
  45. name="specaug",
  46. classes=dict(
  47. specaug=SpecAug,
  48. ),
  49. type_check=AbsSpecAug,
  50. default=None,
  51. optional=True,
  52. )
  53. normalize_choices = ClassChoices(
  54. "normalize",
  55. classes=dict(
  56. global_mvn=GlobalMVN,
  57. utterance_mvn=UtteranceMVN,
  58. ),
  59. type_check=AbsNormalize,
  60. default=None,
  61. optional=True,
  62. )
  63. model_choices = ClassChoices(
  64. "model",
  65. classes=dict(
  66. espnet=ESPnetSVModel,
  67. ),
  68. type_check=FunASRModel,
  69. default="espnet",
  70. )
  71. preencoder_choices = ClassChoices(
  72. name="preencoder",
  73. classes=dict(
  74. sinc=LightweightSincConvs,
  75. linear=LinearProjection,
  76. ),
  77. type_check=AbsPreEncoder,
  78. default=None,
  79. optional=True,
  80. )
  81. encoder_choices = ClassChoices(
  82. "encoder",
  83. classes=dict(
  84. resnet34=ResNet34,
  85. resnet34_sp_l2reg=ResNet34_SP_L2Reg,
  86. rnn=RNNEncoder,
  87. ),
  88. type_check=AbsEncoder,
  89. default="resnet34",
  90. )
  91. postencoder_choices = ClassChoices(
  92. name="postencoder",
  93. classes=dict(
  94. hugging_face_transformers=HuggingFaceTransformersPostEncoder,
  95. ),
  96. type_check=AbsPostEncoder,
  97. default=None,
  98. optional=True,
  99. )
  100. pooling_choices = ClassChoices(
  101. name="pooling_type",
  102. classes=dict(
  103. statistic=StatisticPooling,
  104. ),
  105. type_check=torch.nn.Module,
  106. default="statistic",
  107. )
  108. decoder_choices = ClassChoices(
  109. "decoder",
  110. classes=dict(
  111. dense=DenseDecoder,
  112. ),
  113. type_check=AbsDecoder,
  114. default="dense",
  115. )
  116. class_choices_list = [
  117. # --frontend and --frontend_conf
  118. frontend_choices,
  119. # --specaug and --specaug_conf
  120. specaug_choices,
  121. # --normalize and --normalize_conf
  122. normalize_choices,
  123. # --model and --model_conf
  124. model_choices,
  125. # --preencoder and --preencoder_conf
  126. preencoder_choices,
  127. # --encoder and --encoder_conf
  128. encoder_choices,
  129. # --postencoder and --postencoder_conf
  130. postencoder_choices,
  131. # --pooling and --pooling_conf
  132. pooling_choices,
  133. # --decoder and --decoder_conf
  134. decoder_choices,
  135. ]
  136. def build_sv_model(args):
  137. # token_list
  138. if isinstance(args.token_list, str):
  139. with open(args.token_list, encoding="utf-8") as f:
  140. token_list = [line.rstrip() for line in f]
  141. # Overwriting token_list to keep it as "portable".
  142. args.token_list = list(token_list)
  143. elif isinstance(args.token_list, (tuple, list)):
  144. token_list = list(args.token_list)
  145. else:
  146. raise RuntimeError("token_list must be str or list")
  147. vocab_size = len(token_list)
  148. logging.info(f"Speaker number: {vocab_size}")
  149. # 1. frontend
  150. if args.input_size is None:
  151. # Extract features in the model
  152. frontend_class = frontend_choices.get_class(args.frontend)
  153. frontend = frontend_class(**args.frontend_conf)
  154. input_size = frontend.output_size()
  155. else:
  156. # Give features from data-loader
  157. args.frontend = None
  158. args.frontend_conf = {}
  159. frontend = None
  160. input_size = args.input_size
  161. # 2. Data augmentation for spectrogram
  162. if args.specaug is not None:
  163. specaug_class = specaug_choices.get_class(args.specaug)
  164. specaug = specaug_class(**args.specaug_conf)
  165. else:
  166. specaug = None
  167. # 3. Normalization layer
  168. if args.normalize is not None:
  169. normalize_class = normalize_choices.get_class(args.normalize)
  170. normalize = normalize_class(**args.normalize_conf)
  171. else:
  172. normalize = None
  173. # 4. Pre-encoder input block
  174. # NOTE(kan-bayashi): Use getattr to keep the compatibility
  175. if getattr(args, "preencoder", None) is not None:
  176. preencoder_class = preencoder_choices.get_class(args.preencoder)
  177. preencoder = preencoder_class(**args.preencoder_conf)
  178. input_size = preencoder.output_size()
  179. else:
  180. preencoder = None
  181. # 5. Encoder
  182. encoder_class = encoder_choices.get_class(args.encoder)
  183. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  184. # 6. Post-encoder block
  185. # NOTE(kan-bayashi): Use getattr to keep the compatibility
  186. encoder_output_size = encoder.output_size()
  187. if getattr(args, "postencoder", None) is not None:
  188. postencoder_class = postencoder_choices.get_class(args.postencoder)
  189. postencoder = postencoder_class(
  190. input_size=encoder_output_size, **args.postencoder_conf
  191. )
  192. encoder_output_size = postencoder.output_size()
  193. else:
  194. postencoder = None
  195. # 7. Pooling layer
  196. pooling_class = pooling_choices.get_class(args.pooling_type)
  197. pooling_dim = (2, 3)
  198. eps = 1e-12
  199. if hasattr(args, "pooling_type_conf"):
  200. if "pooling_dim" in args.pooling_type_conf:
  201. pooling_dim = args.pooling_type_conf["pooling_dim"]
  202. if "eps" in args.pooling_type_conf:
  203. eps = args.pooling_type_conf["eps"]
  204. pooling_layer = pooling_class(
  205. pooling_dim=pooling_dim,
  206. eps=eps,
  207. )
  208. if args.pooling_type == "statistic":
  209. encoder_output_size *= 2
  210. # 8. Decoder
  211. decoder_class = decoder_choices.get_class(args.decoder)
  212. decoder = decoder_class(
  213. vocab_size=vocab_size,
  214. encoder_output_size=encoder_output_size,
  215. **args.decoder_conf,
  216. )
  217. # 7. Build model
  218. try:
  219. model_class = model_choices.get_class(args.model)
  220. except AttributeError:
  221. model_class = model_choices.get_class("espnet")
  222. model = model_class(
  223. vocab_size=vocab_size,
  224. token_list=token_list,
  225. frontend=frontend,
  226. specaug=specaug,
  227. normalize=normalize,
  228. preencoder=preencoder,
  229. encoder=encoder,
  230. postencoder=postencoder,
  231. pooling_layer=pooling_layer,
  232. decoder=decoder,
  233. **args.model_conf,
  234. )
  235. # FIXME(kamo): Should be done in model?
  236. # 8. Initialize
  237. if args.init is not None:
  238. initialize(model, args.init)
  239. assert check_return_type(model)
  240. return model