# build_diar_model.py — factory for FunASR speaker-diarization models (SOND, EEND-OLA).
import logging

import torch

from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.abs_profileaug import AbsProfileAug
from funasr.models.specaug.profileaug import ProfileAug
from funasr.models.specaug.specaug import SpecAug, SpecAugLFR
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
  33. frontend_choices = ClassChoices(
  34. name="frontend",
  35. classes=dict(
  36. default=DefaultFrontend,
  37. sliding_window=SlidingWindow,
  38. s3prl=S3prlFrontend,
  39. fused=FusedFrontends,
  40. wav_frontend=WavFrontend,
  41. wav_frontend_mel23=WavFrontendMel23,
  42. ),
  43. default="default",
  44. )
  45. specaug_choices = ClassChoices(
  46. name="specaug",
  47. classes=dict(
  48. specaug=SpecAug,
  49. specaug_lfr=SpecAugLFR,
  50. ),
  51. default=None,
  52. optional=True,
  53. )
  54. profileaug_choices = ClassChoices(
  55. name="profileaug",
  56. classes=dict(
  57. profileaug=ProfileAug,
  58. ),
  59. type_check=AbsProfileAug,
  60. default=None,
  61. optional=True,
  62. )
  63. normalize_choices = ClassChoices(
  64. "normalize",
  65. classes=dict(
  66. global_mvn=GlobalMVN,
  67. utterance_mvn=UtteranceMVN,
  68. ),
  69. default=None,
  70. optional=True,
  71. )
  72. label_aggregator_choices = ClassChoices(
  73. "label_aggregator",
  74. classes=dict(
  75. label_aggregator=LabelAggregate,
  76. label_aggregator_max_pool=LabelAggregateMaxPooling,
  77. ),
  78. default=None,
  79. optional=True,
  80. )
  81. model_choices = ClassChoices(
  82. "model",
  83. classes=dict(
  84. sond=DiarSondModel,
  85. eend_ola=DiarEENDOLAModel,
  86. ),
  87. default="sond",
  88. )
  89. encoder_choices = ClassChoices(
  90. "encoder",
  91. classes=dict(
  92. conformer=ConformerEncoder,
  93. transformer=TransformerEncoder,
  94. rnn=RNNEncoder,
  95. sanm=SANMEncoder,
  96. san=SelfAttentionEncoder,
  97. fsmn=FsmnEncoder,
  98. conv=ConvEncoder,
  99. resnet34=ResNet34Diar,
  100. resnet34_sp_l2reg=ResNet34SpL2RegDiar,
  101. sanm_chunk_opt=SANMEncoderChunkOpt,
  102. data2vec_encoder=Data2VecEncoder,
  103. ecapa_tdnn=ECAPA_TDNN,
  104. eend_ola_transformer=EENDOLATransformerEncoder,
  105. ),
  106. default="resnet34",
  107. )
  108. speaker_encoder_choices = ClassChoices(
  109. "speaker_encoder",
  110. classes=dict(
  111. conformer=ConformerEncoder,
  112. transformer=TransformerEncoder,
  113. rnn=RNNEncoder,
  114. sanm=SANMEncoder,
  115. san=SelfAttentionEncoder,
  116. fsmn=FsmnEncoder,
  117. conv=ConvEncoder,
  118. sanm_chunk_opt=SANMEncoderChunkOpt,
  119. data2vec_encoder=Data2VecEncoder,
  120. ),
  121. default=None,
  122. optional=True
  123. )
  124. cd_scorer_choices = ClassChoices(
  125. "cd_scorer",
  126. classes=dict(
  127. san=SelfAttentionEncoder,
  128. ),
  129. default=None,
  130. optional=True,
  131. )
  132. ci_scorer_choices = ClassChoices(
  133. "ci_scorer",
  134. classes=dict(
  135. dot=DotScorer,
  136. cosine=CosScorer,
  137. conv=ConvEncoder,
  138. ),
  139. type_check=torch.nn.Module,
  140. default=None,
  141. optional=True,
  142. )
  143. # decoder is used for output (e.g. post_net in SOND)
  144. decoder_choices = ClassChoices(
  145. "decoder",
  146. classes=dict(
  147. rnn=RNNEncoder,
  148. fsmn=FsmnEncoder,
  149. ),
  150. type_check=torch.nn.Module,
  151. default="fsmn",
  152. )
  153. # encoder_decoder_attractor is used for EEND-OLA
  154. encoder_decoder_attractor_choices = ClassChoices(
  155. "encoder_decoder_attractor",
  156. classes=dict(
  157. eda=EncoderDecoderAttractor,
  158. ),
  159. type_check=torch.nn.Module,
  160. default="eda",
  161. )
  162. class_choices_list = [
  163. # --frontend and --frontend_conf
  164. frontend_choices,
  165. # --specaug and --specaug_conf
  166. specaug_choices,
  167. # --profileaug and --profileaug_conf
  168. profileaug_choices,
  169. # --normalize and --normalize_conf
  170. normalize_choices,
  171. # --label_aggregator and --label_aggregator_conf
  172. label_aggregator_choices,
  173. # --model and --model_conf
  174. model_choices,
  175. # --encoder and --encoder_conf
  176. encoder_choices,
  177. # --speaker_encoder and --speaker_encoder_conf
  178. speaker_encoder_choices,
  179. # --cd_scorer and cd_scorer_conf
  180. cd_scorer_choices,
  181. # --ci_scorer and ci_scorer_conf
  182. ci_scorer_choices,
  183. # --decoder and --decoder_conf
  184. decoder_choices,
  185. # --eda and --eda_conf
  186. encoder_decoder_attractor_choices,
  187. ]
  188. def build_diar_model(args):
  189. # token_list
  190. if args.token_list is not None:
  191. if isinstance(args.token_list, str):
  192. with open(args.token_list, encoding="utf-8") as f:
  193. token_list = [line.rstrip() for line in f]
  194. # Overwriting token_list to keep it as "portable".
  195. args.token_list = list(token_list)
  196. elif isinstance(args.token_list, (tuple, list)):
  197. token_list = list(args.token_list)
  198. else:
  199. raise RuntimeError("token_list must be str or list")
  200. vocab_size = len(token_list)
  201. logging.info(f"Vocabulary size: {vocab_size}")
  202. else:
  203. token_list = None
  204. vocab_size = None
  205. # frontend
  206. if args.input_size is None:
  207. frontend_class = frontend_choices.get_class(args.frontend)
  208. if args.frontend == 'wav_frontend':
  209. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  210. else:
  211. frontend = frontend_class(**args.frontend_conf)
  212. input_size = frontend.output_size()
  213. else:
  214. args.frontend = None
  215. args.frontend_conf = {}
  216. frontend = None
  217. input_size = args.input_size
  218. if args.model == "sond":
  219. # encoder
  220. encoder_class = encoder_choices.get_class(args.encoder)
  221. encoder = encoder_class(input_size=input_size ,**args.encoder_conf)
  222. # data augmentation for spectrogram
  223. if args.specaug is not None:
  224. specaug_class = specaug_choices.get_class(args.specaug)
  225. specaug = specaug_class(**args.specaug_conf)
  226. else:
  227. specaug = None
  228. # Data augmentation for Profiles
  229. if hasattr(args, "profileaug") and args.profileaug is not None:
  230. profileaug_class = profileaug_choices.get_class(args.profileaug)
  231. profileaug = profileaug_class(**args.profileaug_conf)
  232. else:
  233. profileaug = None
  234. # normalization layer
  235. if args.normalize is not None:
  236. normalize_class = normalize_choices.get_class(args.normalize)
  237. normalize = normalize_class(**args.normalize_conf)
  238. else:
  239. normalize = None
  240. # speaker encoder
  241. if getattr(args, "speaker_encoder", None) is not None:
  242. speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
  243. speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
  244. else:
  245. speaker_encoder = None
  246. # ci scorer
  247. if getattr(args, "ci_scorer", None) is not None:
  248. ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
  249. ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
  250. else:
  251. ci_scorer = None
  252. # cd scorer
  253. if getattr(args, "cd_scorer", None) is not None:
  254. cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
  255. cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
  256. else:
  257. cd_scorer = None
  258. # decoder
  259. decoder_class = decoder_choices.get_class(args.decoder)
  260. decoder = decoder_class(**args.decoder_conf)
  261. # logger aggregator
  262. if getattr(args, "label_aggregator", None) is not None:
  263. label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
  264. label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
  265. else:
  266. label_aggregator = None
  267. model_class = model_choices.get_class(args.model)
  268. model = model_class(
  269. vocab_size=vocab_size,
  270. frontend=frontend,
  271. specaug=specaug,
  272. profileaug=profileaug,
  273. normalize=normalize,
  274. label_aggregator=label_aggregator,
  275. encoder=encoder,
  276. speaker_encoder=speaker_encoder,
  277. ci_scorer=ci_scorer,
  278. cd_scorer=cd_scorer,
  279. decoder=decoder,
  280. token_list=token_list,
  281. **args.model_conf,
  282. )
  283. elif args.model == "eend_ola":
  284. # encoder
  285. encoder_class = encoder_choices.get_class(args.encoder)
  286. encoder = encoder_class(**args.encoder_conf)
  287. # encoder-decoder attractor
  288. encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
  289. encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
  290. # 9. Build model
  291. model_class = model_choices.get_class(args.model)
  292. model = model_class(
  293. frontend=frontend,
  294. encoder=encoder,
  295. encoder_decoder_attractor=encoder_decoder_attractor,
  296. **args.model_conf,
  297. )
  298. else:
  299. raise NotImplementedError("Not supported model: {}".format(args.model))
  300. # 10. Initialize
  301. if args.init is not None:
  302. initialize(model, args.init)
  303. return model