build_diar_model.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. import logging
  2. import torch
  3. from funasr.layers.global_mvn import GlobalMVN
  4. from funasr.layers.label_aggregation import LabelAggregate
  5. from funasr.layers.utterance_mvn import UtteranceMVN
  6. from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
  7. from funasr.models.e2e_diar_sond import DiarSondModel
  8. from funasr.models.encoder.conformer_encoder import ConformerEncoder
  9. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  10. from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
  11. from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
  12. from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
  13. from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
  14. from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
  15. from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
  16. from funasr.models.encoder.rnn_encoder import RNNEncoder
  17. from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
  18. from funasr.models.encoder.transformer_encoder import TransformerEncoder
  19. from funasr.models.frontend.default import DefaultFrontend
  20. from funasr.models.frontend.fused import FusedFrontends
  21. from funasr.models.frontend.s3prl import S3prlFrontend
  22. from funasr.models.frontend.wav_frontend import WavFrontend
  23. from funasr.models.frontend.wav_frontend import WavFrontendMel23
  24. from funasr.models.frontend.windowing import SlidingWindow
  25. from funasr.models.specaug.specaug import SpecAug
  26. from funasr.models.specaug.specaug import SpecAugLFR
  27. from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
  28. from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
  29. from funasr.torch_utils.initialize import initialize
  30. from funasr.train.class_choices import ClassChoices
  31. frontend_choices = ClassChoices(
  32. name="frontend",
  33. classes=dict(
  34. default=DefaultFrontend,
  35. sliding_window=SlidingWindow,
  36. s3prl=S3prlFrontend,
  37. fused=FusedFrontends,
  38. wav_frontend=WavFrontend,
  39. wav_frontend_mel23=WavFrontendMel23,
  40. ),
  41. default="default",
  42. )
  43. specaug_choices = ClassChoices(
  44. name="specaug",
  45. classes=dict(
  46. specaug=SpecAug,
  47. specaug_lfr=SpecAugLFR,
  48. ),
  49. default=None,
  50. optional=True,
  51. )
  52. normalize_choices = ClassChoices(
  53. "normalize",
  54. classes=dict(
  55. global_mvn=GlobalMVN,
  56. utterance_mvn=UtteranceMVN,
  57. ),
  58. default=None,
  59. optional=True,
  60. )
  61. label_aggregator_choices = ClassChoices(
  62. "label_aggregator",
  63. classes=dict(
  64. label_aggregator=LabelAggregate
  65. ),
  66. default=None,
  67. optional=True,
  68. )
  69. model_choices = ClassChoices(
  70. "model",
  71. classes=dict(
  72. sond=DiarSondModel,
  73. eend_ola=DiarEENDOLAModel,
  74. ),
  75. default="sond",
  76. )
  77. encoder_choices = ClassChoices(
  78. "encoder",
  79. classes=dict(
  80. conformer=ConformerEncoder,
  81. transformer=TransformerEncoder,
  82. rnn=RNNEncoder,
  83. sanm=SANMEncoder,
  84. san=SelfAttentionEncoder,
  85. fsmn=FsmnEncoder,
  86. conv=ConvEncoder,
  87. resnet34=ResNet34Diar,
  88. resnet34_sp_l2reg=ResNet34SpL2RegDiar,
  89. sanm_chunk_opt=SANMEncoderChunkOpt,
  90. data2vec_encoder=Data2VecEncoder,
  91. ecapa_tdnn=ECAPA_TDNN,
  92. eend_ola_transformer=EENDOLATransformerEncoder,
  93. ),
  94. default="resnet34",
  95. )
  96. speaker_encoder_choices = ClassChoices(
  97. "speaker_encoder",
  98. classes=dict(
  99. conformer=ConformerEncoder,
  100. transformer=TransformerEncoder,
  101. rnn=RNNEncoder,
  102. sanm=SANMEncoder,
  103. san=SelfAttentionEncoder,
  104. fsmn=FsmnEncoder,
  105. conv=ConvEncoder,
  106. sanm_chunk_opt=SANMEncoderChunkOpt,
  107. data2vec_encoder=Data2VecEncoder,
  108. ),
  109. default=None,
  110. optional=True
  111. )
  112. cd_scorer_choices = ClassChoices(
  113. "cd_scorer",
  114. classes=dict(
  115. san=SelfAttentionEncoder,
  116. ),
  117. default=None,
  118. optional=True,
  119. )
  120. ci_scorer_choices = ClassChoices(
  121. "ci_scorer",
  122. classes=dict(
  123. dot=DotScorer,
  124. cosine=CosScorer,
  125. conv=ConvEncoder,
  126. ),
  127. type_check=torch.nn.Module,
  128. default=None,
  129. optional=True,
  130. )
  131. # decoder is used for output (e.g. post_net in SOND)
  132. decoder_choices = ClassChoices(
  133. "decoder",
  134. classes=dict(
  135. rnn=RNNEncoder,
  136. fsmn=FsmnEncoder,
  137. ),
  138. type_check=torch.nn.Module,
  139. default="fsmn",
  140. )
  141. # encoder_decoder_attractor is used for EEND-OLA
  142. encoder_decoder_attractor_choices = ClassChoices(
  143. "encoder_decoder_attractor",
  144. classes=dict(
  145. eda=EncoderDecoderAttractor,
  146. ),
  147. type_check=torch.nn.Module,
  148. default="eda",
  149. )
  150. class_choices_list = [
  151. # --frontend and --frontend_conf
  152. frontend_choices,
  153. # --specaug and --specaug_conf
  154. specaug_choices,
  155. # --normalize and --normalize_conf
  156. normalize_choices,
  157. # --label_aggregator and --label_aggregator_conf
  158. label_aggregator_choices,
  159. # --model and --model_conf
  160. model_choices,
  161. # --encoder and --encoder_conf
  162. encoder_choices,
  163. # --speaker_encoder and --speaker_encoder_conf
  164. speaker_encoder_choices,
  165. # --cd_scorer and cd_scorer_conf
  166. cd_scorer_choices,
  167. # --ci_scorer and ci_scorer_conf
  168. ci_scorer_choices,
  169. # --decoder and --decoder_conf
  170. decoder_choices,
  171. # --eda and --eda_conf
  172. encoder_decoder_attractor_choices,
  173. ]
  174. def build_diar_model(args):
  175. # token_list
  176. if args.token_list is not None:
  177. with open(args.token_list) as f:
  178. token_list = [line.rstrip() for line in f]
  179. args.token_list = list(token_list)
  180. vocab_size = len(token_list)
  181. logging.info(f"Vocabulary size: {vocab_size}")
  182. else:
  183. vocab_size = None
  184. # frontend
  185. if args.input_size is None:
  186. frontend_class = frontend_choices.get_class(args.frontend)
  187. if args.frontend == 'wav_frontend':
  188. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  189. else:
  190. frontend = frontend_class(**args.frontend_conf)
  191. input_size = frontend.output_size()
  192. else:
  193. args.frontend = None
  194. args.frontend_conf = {}
  195. frontend = None
  196. input_size = args.input_size
  197. # encoder
  198. encoder_class = encoder_choices.get_class(args.encoder)
  199. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  200. if args.model_name == "sond":
  201. # data augmentation for spectrogram
  202. if args.specaug is not None:
  203. specaug_class = specaug_choices.get_class(args.specaug)
  204. specaug = specaug_class(**args.specaug_conf)
  205. else:
  206. specaug = None
  207. # normalization layer
  208. if args.normalize is not None:
  209. normalize_class = normalize_choices.get_class(args.normalize)
  210. normalize = normalize_class(**args.normalize_conf)
  211. else:
  212. normalize = None
  213. # speaker encoder
  214. if getattr(args, "speaker_encoder", None) is not None:
  215. speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
  216. speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
  217. else:
  218. speaker_encoder = None
  219. # ci scorer
  220. if getattr(args, "ci_scorer", None) is not None:
  221. ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
  222. ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
  223. else:
  224. ci_scorer = None
  225. # cd scorer
  226. if getattr(args, "cd_scorer", None) is not None:
  227. cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
  228. cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
  229. else:
  230. cd_scorer = None
  231. # decoder
  232. decoder_class = decoder_choices.get_class(args.decoder)
  233. decoder = decoder_class(
  234. vocab_size=vocab_size,
  235. encoder_output_size=encoder.output_size(),
  236. **args.decoder_conf,
  237. )
  238. # logger aggregator
  239. if getattr(args, "label_aggregator", None) is not None:
  240. label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
  241. label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
  242. else:
  243. label_aggregator = None
  244. model_class = model_choices.get_class(args.model)
  245. model = model_class(
  246. vocab_size=vocab_size,
  247. frontend=frontend,
  248. specaug=specaug,
  249. normalize=normalize,
  250. label_aggregator=label_aggregator,
  251. encoder=encoder,
  252. speaker_encoder=speaker_encoder,
  253. ci_scorer=ci_scorer,
  254. cd_scorer=cd_scorer,
  255. decoder=decoder,
  256. token_list=token_list,
  257. **args.model_conf,
  258. )
  259. elif args.model_name == "eend_ola":
  260. # encoder-decoder attractor
  261. encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
  262. encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
  263. # 9. Build model
  264. model_class = model_choices.get_class(args.model)
  265. model = model_class(
  266. frontend=frontend,
  267. encoder=encoder,
  268. encoder_decoder_attractor=encoder_decoder_attractor,
  269. **args.model_conf,
  270. )
  271. else:
  272. raise NotImplementedError("Not supported model: {}".format(args.model))
  273. # 10. Initialize
  274. if args.init is not None:
  275. initialize(model, args.init)
  276. return model