build_asr_model.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
import logging

from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
    DynamicConvolutionTransformerDecoder,
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
    LightweightConvolutionTransformerDecoder,  # noqa: H301
    ParaformerDecoderSAN,
    SAAsrTransformerDecoder,
    TransformerDecoder,
)
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_asr_paraformer import (
    BiCifParaformer,
    ContextualParaformer,
    Paraformer,
    ParaformerBert,
    ParaformerOnline,
)
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.e2e_sa_asr import SAASRModel
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.default import DefaultFrontend, MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.specaug.specaug import SpecAug, SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
  51. frontend_choices = ClassChoices(
  52. name="frontend",
  53. classes=dict(
  54. default=DefaultFrontend,
  55. sliding_window=SlidingWindow,
  56. s3prl=S3prlFrontend,
  57. fused=FusedFrontends,
  58. wav_frontend=WavFrontend,
  59. multichannelfrontend=MultiChannelFrontend,
  60. ),
  61. default="default",
  62. )
  63. specaug_choices = ClassChoices(
  64. name="specaug",
  65. classes=dict(
  66. specaug=SpecAug,
  67. specaug_lfr=SpecAugLFR,
  68. ),
  69. default=None,
  70. optional=True,
  71. )
  72. normalize_choices = ClassChoices(
  73. "normalize",
  74. classes=dict(
  75. global_mvn=GlobalMVN,
  76. utterance_mvn=UtteranceMVN,
  77. ),
  78. default=None,
  79. optional=True,
  80. )
  81. model_choices = ClassChoices(
  82. "model",
  83. classes=dict(
  84. asr=ASRModel,
  85. uniasr=UniASR,
  86. paraformer=Paraformer,
  87. paraformer_online=ParaformerOnline,
  88. paraformer_bert=ParaformerBert,
  89. bicif_paraformer=BiCifParaformer,
  90. contextual_paraformer=ContextualParaformer,
  91. neatcontextual_paraformer=NeatContextualParaformer,
  92. mfcca=MFCCA,
  93. timestamp_prediction=TimestampPredictor,
  94. rnnt=TransducerModel,
  95. rnnt_unified=UnifiedTransducerModel,
  96. sa_asr=SAASRModel,
  97. ),
  98. default="asr",
  99. )
  100. encoder_choices = ClassChoices(
  101. "encoder",
  102. classes=dict(
  103. conformer=ConformerEncoder,
  104. transformer=TransformerEncoder,
  105. rnn=RNNEncoder,
  106. sanm=SANMEncoder,
  107. sanm_chunk_opt=SANMEncoderChunkOpt,
  108. data2vec_encoder=Data2VecEncoder,
  109. mfcca_enc=MFCCAEncoder,
  110. chunk_conformer=ConformerChunkEncoder,
  111. ),
  112. default="rnn",
  113. )
  114. asr_encoder_choices = ClassChoices(
  115. "asr_encoder",
  116. classes=dict(
  117. conformer=ConformerEncoder,
  118. transformer=TransformerEncoder,
  119. rnn=RNNEncoder,
  120. sanm=SANMEncoder,
  121. sanm_chunk_opt=SANMEncoderChunkOpt,
  122. data2vec_encoder=Data2VecEncoder,
  123. mfcca_enc=MFCCAEncoder,
  124. ),
  125. default="rnn",
  126. )
  127. spk_encoder_choices = ClassChoices(
  128. "spk_encoder",
  129. classes=dict(
  130. resnet34_diar=ResNet34Diar,
  131. ),
  132. default="resnet34_diar",
  133. )
  134. encoder_choices2 = ClassChoices(
  135. "encoder2",
  136. classes=dict(
  137. conformer=ConformerEncoder,
  138. transformer=TransformerEncoder,
  139. rnn=RNNEncoder,
  140. sanm=SANMEncoder,
  141. sanm_chunk_opt=SANMEncoderChunkOpt,
  142. ),
  143. default="rnn",
  144. )
  145. decoder_choices = ClassChoices(
  146. "decoder",
  147. classes=dict(
  148. transformer=TransformerDecoder,
  149. lightweight_conv=LightweightConvolutionTransformerDecoder,
  150. lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
  151. dynamic_conv=DynamicConvolutionTransformerDecoder,
  152. dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
  153. rnn=RNNDecoder,
  154. fsmn_scama_opt=FsmnDecoderSCAMAOpt,
  155. paraformer_decoder_sanm=ParaformerSANMDecoder,
  156. paraformer_decoder_san=ParaformerDecoderSAN,
  157. contextual_paraformer_decoder=ContextualParaformerDecoder,
  158. sa_decoder=SAAsrTransformerDecoder,
  159. ),
  160. default="rnn",
  161. )
  162. decoder_choices2 = ClassChoices(
  163. "decoder2",
  164. classes=dict(
  165. transformer=TransformerDecoder,
  166. lightweight_conv=LightweightConvolutionTransformerDecoder,
  167. lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
  168. dynamic_conv=DynamicConvolutionTransformerDecoder,
  169. dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
  170. rnn=RNNDecoder,
  171. fsmn_scama_opt=FsmnDecoderSCAMAOpt,
  172. paraformer_decoder_sanm=ParaformerSANMDecoder,
  173. ),
  174. type_check=AbsDecoder,
  175. default="rnn",
  176. )
  177. predictor_choices = ClassChoices(
  178. name="predictor",
  179. classes=dict(
  180. cif_predictor=CifPredictor,
  181. ctc_predictor=None,
  182. cif_predictor_v2=CifPredictorV2,
  183. cif_predictor_v3=CifPredictorV3,
  184. ),
  185. default="cif_predictor",
  186. optional=True,
  187. )
  188. predictor_choices2 = ClassChoices(
  189. name="predictor2",
  190. classes=dict(
  191. cif_predictor=CifPredictor,
  192. ctc_predictor=None,
  193. cif_predictor_v2=CifPredictorV2,
  194. ),
  195. default="cif_predictor",
  196. optional=True,
  197. )
  198. stride_conv_choices = ClassChoices(
  199. name="stride_conv",
  200. classes=dict(
  201. stride_conv1d=Conv1dSubsampling
  202. ),
  203. default="stride_conv1d",
  204. optional=True,
  205. )
  206. rnnt_decoder_choices = ClassChoices(
  207. name="rnnt_decoder",
  208. classes=dict(
  209. rnnt=RNNTDecoder,
  210. ),
  211. default="rnnt",
  212. optional=True,
  213. )
  214. joint_network_choices = ClassChoices(
  215. name="joint_network",
  216. classes=dict(
  217. joint_network=JointNetwork,
  218. ),
  219. default="joint_network",
  220. optional=True,
  221. )
  222. class_choices_list = [
  223. # --frontend and --frontend_conf
  224. frontend_choices,
  225. # --specaug and --specaug_conf
  226. specaug_choices,
  227. # --normalize and --normalize_conf
  228. normalize_choices,
  229. # --model and --model_conf
  230. model_choices,
  231. # --encoder and --encoder_conf
  232. encoder_choices,
  233. # --decoder and --decoder_conf
  234. decoder_choices,
  235. # --predictor and --predictor_conf
  236. predictor_choices,
  237. # --encoder2 and --encoder2_conf
  238. encoder_choices2,
  239. # --decoder2 and --decoder2_conf
  240. decoder_choices2,
  241. # --predictor2 and --predictor2_conf
  242. predictor_choices2,
  243. # --stride_conv and --stride_conv_conf
  244. stride_conv_choices,
  245. # --rnnt_decoder and --rnnt_decoder_conf
  246. rnnt_decoder_choices,
  247. # --joint_network and --joint_network_conf
  248. joint_network_choices,
  249. # --asr_encoder and --asr_encoder_conf
  250. asr_encoder_choices,
  251. # --spk_encoder and --spk_encoder_conf
  252. spk_encoder_choices,
  253. ]
  254. def build_asr_model(args):
  255. # token_list
  256. if isinstance(args.token_list, str):
  257. with open(args.token_list, encoding="utf-8") as f:
  258. token_list = [line.rstrip() for line in f]
  259. args.token_list = list(token_list)
  260. vocab_size = len(token_list)
  261. logging.info(f"Vocabulary size: {vocab_size}")
  262. elif isinstance(args.token_list, (tuple, list)):
  263. token_list = list(args.token_list)
  264. vocab_size = len(token_list)
  265. logging.info(f"Vocabulary size: {vocab_size}")
  266. else:
  267. token_list = None
  268. vocab_size = None
  269. # frontend
  270. if hasattr(args, "input_size") and args.input_size is None:
  271. frontend_class = frontend_choices.get_class(args.frontend)
  272. if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
  273. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  274. else:
  275. frontend = frontend_class(**args.frontend_conf)
  276. input_size = frontend.output_size()
  277. else:
  278. args.frontend = None
  279. args.frontend_conf = {}
  280. frontend = None
  281. input_size = args.input_size if hasattr(args, "input_size") else None
  282. # data augmentation for spectrogram
  283. if args.specaug is not None:
  284. specaug_class = specaug_choices.get_class(args.specaug)
  285. specaug = specaug_class(**args.specaug_conf)
  286. else:
  287. specaug = None
  288. # normalization layer
  289. if args.normalize is not None:
  290. normalize_class = normalize_choices.get_class(args.normalize)
  291. if args.model == "mfcca":
  292. normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
  293. else:
  294. normalize = normalize_class(**args.normalize_conf)
  295. else:
  296. normalize = None
  297. # encoder
  298. encoder_class = encoder_choices.get_class(args.encoder)
  299. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  300. # decoder
  301. decoder_class = decoder_choices.get_class(args.decoder)
  302. decoder = decoder_class(
  303. vocab_size=vocab_size,
  304. encoder_output_size=encoder.output_size(),
  305. **args.decoder_conf,
  306. )
  307. # ctc
  308. ctc = CTC(
  309. odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
  310. )
  311. if args.model in ["asr", "mfcca"]:
  312. model_class = model_choices.get_class(args.model)
  313. model = model_class(
  314. vocab_size=vocab_size,
  315. frontend=frontend,
  316. specaug=specaug,
  317. normalize=normalize,
  318. encoder=encoder,
  319. decoder=decoder,
  320. ctc=ctc,
  321. token_list=token_list,
  322. **args.model_conf,
  323. )
  324. elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
  325. "contextual_paraformer", "neatcontextual_paraformer"]:
  326. # predictor
  327. predictor_class = predictor_choices.get_class(args.predictor)
  328. predictor = predictor_class(**args.predictor_conf)
  329. model_class = model_choices.get_class(args.model)
  330. model = model_class(
  331. vocab_size=vocab_size,
  332. frontend=frontend,
  333. specaug=specaug,
  334. normalize=normalize,
  335. encoder=encoder,
  336. decoder=decoder,
  337. ctc=ctc,
  338. token_list=token_list,
  339. predictor=predictor,
  340. **args.model_conf,
  341. )
  342. elif args.model == "uniasr":
  343. # stride_conv
  344. stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
  345. stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
  346. odim=input_size + encoder.output_size())
  347. stride_conv_output_size = stride_conv.output_size()
  348. # encoder2
  349. encoder_class2 = encoder_choices2.get_class(args.encoder2)
  350. encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
  351. # decoder2
  352. decoder_class2 = decoder_choices2.get_class(args.decoder2)
  353. decoder2 = decoder_class2(
  354. vocab_size=vocab_size,
  355. encoder_output_size=encoder2.output_size(),
  356. **args.decoder2_conf,
  357. )
  358. # ctc2
  359. ctc2 = CTC(
  360. odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
  361. )
  362. # predictor
  363. predictor_class = predictor_choices.get_class(args.predictor)
  364. predictor = predictor_class(**args.predictor_conf)
  365. # predictor2
  366. predictor_class = predictor_choices2.get_class(args.predictor2)
  367. predictor2 = predictor_class(**args.predictor2_conf)
  368. model_class = model_choices.get_class(args.model)
  369. model = model_class(
  370. vocab_size=vocab_size,
  371. frontend=frontend,
  372. specaug=specaug,
  373. normalize=normalize,
  374. encoder=encoder,
  375. decoder=decoder,
  376. ctc=ctc,
  377. token_list=token_list,
  378. predictor=predictor,
  379. ctc2=ctc2,
  380. encoder2=encoder2,
  381. decoder2=decoder2,
  382. predictor2=predictor2,
  383. stride_conv=stride_conv,
  384. **args.model_conf,
  385. )
  386. elif args.model == "timestamp_prediction":
  387. # predictor
  388. predictor_class = predictor_choices.get_class(args.predictor)
  389. predictor = predictor_class(**args.predictor_conf)
  390. model_class = model_choices.get_class(args.model)
  391. model = model_class(
  392. frontend=frontend,
  393. encoder=encoder,
  394. predictor=predictor,
  395. token_list=token_list,
  396. **args.model_conf,
  397. )
  398. elif args.model == "rnnt" or args.model == "rnnt_unified":
  399. # 5. Decoder
  400. encoder_output_size = encoder.output_size()
  401. rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
  402. decoder = rnnt_decoder_class(
  403. vocab_size,
  404. **args.rnnt_decoder_conf,
  405. )
  406. decoder_output_size = decoder.output_size
  407. if getattr(args, "decoder", None) is not None:
  408. att_decoder_class = decoder_choices.get_class(args.decoder)
  409. att_decoder = att_decoder_class(
  410. vocab_size=vocab_size,
  411. encoder_output_size=encoder_output_size,
  412. **args.decoder_conf,
  413. )
  414. else:
  415. att_decoder = None
  416. # 6. Joint Network
  417. joint_network = JointNetwork(
  418. vocab_size,
  419. encoder_output_size,
  420. decoder_output_size,
  421. **args.joint_network_conf,
  422. )
  423. model_class = model_choices.get_class(args.model)
  424. # 7. Build model
  425. model = model_class(
  426. vocab_size=vocab_size,
  427. token_list=token_list,
  428. frontend=frontend,
  429. specaug=specaug,
  430. normalize=normalize,
  431. encoder=encoder,
  432. decoder=decoder,
  433. att_decoder=att_decoder,
  434. joint_network=joint_network,
  435. **args.model_conf,
  436. )
  437. elif args.model == "sa_asr":
  438. asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
  439. asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
  440. spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
  441. spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
  442. decoder = decoder_class(
  443. vocab_size=vocab_size,
  444. encoder_output_size=asr_encoder.output_size(),
  445. **args.decoder_conf,
  446. )
  447. ctc = CTC(
  448. odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
  449. )
  450. model_class = model_choices.get_class(args.model)
  451. model = model_class(
  452. vocab_size=vocab_size,
  453. frontend=frontend,
  454. specaug=specaug,
  455. normalize=normalize,
  456. asr_encoder=asr_encoder,
  457. spk_encoder=spk_encoder,
  458. decoder=decoder,
  459. ctc=ctc,
  460. token_list=token_list,
  461. **args.model_conf,
  462. )
  463. else:
  464. raise NotImplementedError("Not supported model: {}".format(args.model))
  465. # initialize
  466. if args.init is not None:
  467. initialize(model, args.init)
  468. return model