build_sv_model.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. import logging
  2. import torch
  3. from funasr.layers.abs_normalize import AbsNormalize
  4. from funasr.layers.global_mvn import GlobalMVN
  5. from funasr.layers.utterance_mvn import UtteranceMVN
  6. from funasr.models.base_model import FunASRModel
  7. from funasr.models.decoder.abs_decoder import AbsDecoder
  8. from funasr.models.decoder.sv_decoder import DenseDecoder
  9. from funasr.models.e2e_sv import ESPnetSVModel
  10. from funasr.models.encoder.abs_encoder import AbsEncoder
  11. from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
  12. from funasr.models.encoder.rnn_encoder import RNNEncoder
  13. from funasr.models.frontend.abs_frontend import AbsFrontend
  14. from funasr.models.frontend.default import DefaultFrontend
  15. from funasr.models.frontend.fused import FusedFrontends
  16. from funasr.models.frontend.s3prl import S3prlFrontend
  17. from funasr.models.frontend.wav_frontend import WavFrontend
  18. from funasr.models.frontend.windowing import SlidingWindow
  19. from funasr.models.pooling.statistic_pooling import StatisticPooling
  20. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  21. from funasr.models.postencoder.hugging_face_transformers_postencoder import (
  22. HuggingFaceTransformersPostEncoder, # noqa: H301
  23. )
  24. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  25. from funasr.models.preencoder.linear import LinearProjection
  26. from funasr.models.preencoder.sinc import LightweightSincConvs
  27. from funasr.models.specaug.abs_specaug import AbsSpecAug
  28. from funasr.models.specaug.specaug import SpecAug
  29. from funasr.torch_utils.initialize import initialize
  30. from funasr.train.class_choices import ClassChoices
  31. frontend_choices = ClassChoices(
  32. name="frontend",
  33. classes=dict(
  34. default=DefaultFrontend,
  35. sliding_window=SlidingWindow,
  36. s3prl=S3prlFrontend,
  37. fused=FusedFrontends,
  38. wav_frontend=WavFrontend,
  39. ),
  40. type_check=AbsFrontend,
  41. default="default",
  42. )
  43. specaug_choices = ClassChoices(
  44. name="specaug",
  45. classes=dict(
  46. specaug=SpecAug,
  47. ),
  48. type_check=AbsSpecAug,
  49. default=None,
  50. optional=True,
  51. )
  52. normalize_choices = ClassChoices(
  53. "normalize",
  54. classes=dict(
  55. global_mvn=GlobalMVN,
  56. utterance_mvn=UtteranceMVN,
  57. ),
  58. type_check=AbsNormalize,
  59. default=None,
  60. optional=True,
  61. )
  62. model_choices = ClassChoices(
  63. "model",
  64. classes=dict(
  65. espnet=ESPnetSVModel,
  66. ),
  67. type_check=FunASRModel,
  68. default="espnet",
  69. )
  70. preencoder_choices = ClassChoices(
  71. name="preencoder",
  72. classes=dict(
  73. sinc=LightweightSincConvs,
  74. linear=LinearProjection,
  75. ),
  76. type_check=AbsPreEncoder,
  77. default=None,
  78. optional=True,
  79. )
  80. encoder_choices = ClassChoices(
  81. "encoder",
  82. classes=dict(
  83. resnet34=ResNet34,
  84. resnet34_sp_l2reg=ResNet34_SP_L2Reg,
  85. rnn=RNNEncoder,
  86. ),
  87. type_check=AbsEncoder,
  88. default="resnet34",
  89. )
  90. postencoder_choices = ClassChoices(
  91. name="postencoder",
  92. classes=dict(
  93. hugging_face_transformers=HuggingFaceTransformersPostEncoder,
  94. ),
  95. type_check=AbsPostEncoder,
  96. default=None,
  97. optional=True,
  98. )
  99. pooling_choices = ClassChoices(
  100. name="pooling_type",
  101. classes=dict(
  102. statistic=StatisticPooling,
  103. ),
  104. type_check=torch.nn.Module,
  105. default="statistic",
  106. )
  107. decoder_choices = ClassChoices(
  108. "decoder",
  109. classes=dict(
  110. dense=DenseDecoder,
  111. ),
  112. type_check=AbsDecoder,
  113. default="dense",
  114. )
  115. class_choices_list = [
  116. # --frontend and --frontend_conf
  117. frontend_choices,
  118. # --specaug and --specaug_conf
  119. specaug_choices,
  120. # --normalize and --normalize_conf
  121. normalize_choices,
  122. # --model and --model_conf
  123. model_choices,
  124. # --preencoder and --preencoder_conf
  125. preencoder_choices,
  126. # --encoder and --encoder_conf
  127. encoder_choices,
  128. # --postencoder and --postencoder_conf
  129. postencoder_choices,
  130. # --pooling and --pooling_conf
  131. pooling_choices,
  132. # --decoder and --decoder_conf
  133. decoder_choices,
  134. ]
  135. def build_sv_model(args):
  136. # token_list
  137. if isinstance(args.token_list, str):
  138. with open(args.token_list, encoding="utf-8") as f:
  139. token_list = [line.rstrip() for line in f]
  140. # Overwriting token_list to keep it as "portable".
  141. args.token_list = list(token_list)
  142. elif isinstance(args.token_list, (tuple, list)):
  143. token_list = list(args.token_list)
  144. else:
  145. raise RuntimeError("token_list must be str or list")
  146. vocab_size = len(token_list)
  147. logging.info(f"Speaker number: {vocab_size}")
  148. # 1. frontend
  149. if args.input_size is None:
  150. # Extract features in the model
  151. frontend_class = frontend_choices.get_class(args.frontend)
  152. frontend = frontend_class(**args.frontend_conf)
  153. input_size = frontend.output_size()
  154. else:
  155. # Give features from data-loader
  156. args.frontend = None
  157. args.frontend_conf = {}
  158. frontend = None
  159. input_size = args.input_size
  160. # 2. Data augmentation for spectrogram
  161. if args.specaug is not None:
  162. specaug_class = specaug_choices.get_class(args.specaug)
  163. specaug = specaug_class(**args.specaug_conf)
  164. else:
  165. specaug = None
  166. # 3. Normalization layer
  167. if args.normalize is not None:
  168. normalize_class = normalize_choices.get_class(args.normalize)
  169. normalize = normalize_class(**args.normalize_conf)
  170. else:
  171. normalize = None
  172. # 4. Pre-encoder input block
  173. # NOTE(kan-bayashi): Use getattr to keep the compatibility
  174. if getattr(args, "preencoder", None) is not None:
  175. preencoder_class = preencoder_choices.get_class(args.preencoder)
  176. preencoder = preencoder_class(**args.preencoder_conf)
  177. input_size = preencoder.output_size()
  178. else:
  179. preencoder = None
  180. # 5. Encoder
  181. encoder_class = encoder_choices.get_class(args.encoder)
  182. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  183. # 6. Post-encoder block
  184. # NOTE(kan-bayashi): Use getattr to keep the compatibility
  185. encoder_output_size = encoder.output_size()
  186. if getattr(args, "postencoder", None) is not None:
  187. postencoder_class = postencoder_choices.get_class(args.postencoder)
  188. postencoder = postencoder_class(
  189. input_size=encoder_output_size, **args.postencoder_conf
  190. )
  191. encoder_output_size = postencoder.output_size()
  192. else:
  193. postencoder = None
  194. # 7. Pooling layer
  195. pooling_class = pooling_choices.get_class(args.pooling_type)
  196. pooling_dim = (2, 3)
  197. eps = 1e-12
  198. if hasattr(args, "pooling_type_conf"):
  199. if "pooling_dim" in args.pooling_type_conf:
  200. pooling_dim = args.pooling_type_conf["pooling_dim"]
  201. if "eps" in args.pooling_type_conf:
  202. eps = args.pooling_type_conf["eps"]
  203. pooling_layer = pooling_class(
  204. pooling_dim=pooling_dim,
  205. eps=eps,
  206. )
  207. if args.pooling_type == "statistic":
  208. encoder_output_size *= 2
  209. # 8. Decoder
  210. decoder_class = decoder_choices.get_class(args.decoder)
  211. decoder = decoder_class(
  212. vocab_size=vocab_size,
  213. encoder_output_size=encoder_output_size,
  214. **args.decoder_conf,
  215. )
  216. # 7. Build model
  217. try:
  218. model_class = model_choices.get_class(args.model)
  219. except AttributeError:
  220. model_class = model_choices.get_class("espnet")
  221. model = model_class(
  222. vocab_size=vocab_size,
  223. token_list=token_list,
  224. frontend=frontend,
  225. specaug=specaug,
  226. normalize=normalize,
  227. preencoder=preencoder,
  228. encoder=encoder,
  229. postencoder=postencoder,
  230. pooling_layer=pooling_layer,
  231. decoder=decoder,
  232. **args.model_conf,
  233. )
  234. # FIXME(kamo): Should be done in model?
  235. # 8. Initialize
  236. if args.init is not None:
  237. initialize(model, args.init)
  238. return model