# build_asr_model.py
  1. import logging
  2. from funasr.layers.global_mvn import GlobalMVN
  3. from funasr.layers.utterance_mvn import UtteranceMVN
  4. from funasr.models.ctc import CTC
  5. from funasr.models.decoder.abs_decoder import AbsDecoder
  6. from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
  7. from funasr.models.decoder.rnn_decoder import RNNDecoder
  8. from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
  9. from funasr.models.decoder.transformer_decoder import (
  10. DynamicConvolution2DTransformerDecoder, # noqa: H301
  11. )
  12. from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
  13. from funasr.models.decoder.transformer_decoder import (
  14. LightweightConvolution2DTransformerDecoder, # noqa: H301
  15. )
  16. from funasr.models.decoder.transformer_decoder import (
  17. LightweightConvolutionTransformerDecoder, # noqa: H301
  18. )
  19. from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
  20. from funasr.models.decoder.transformer_decoder import TransformerDecoder
  21. from funasr.models.decoder.rnnt_decoder import RNNTDecoder
  22. from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
  23. from funasr.models.e2e_asr import ASRModel
  24. from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
  25. from funasr.models.e2e_asr_mfcca import MFCCA
  26. from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
  27. from funasr.models.e2e_asr_bat import BATModel
  28. from funasr.models.e2e_sa_asr import SAASRModel
  29. from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
  30. from funasr.models.e2e_tp import TimestampPredictor
  31. from funasr.models.e2e_uni_asr import UniASR
  32. from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
  33. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  34. from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
  35. from funasr.models.encoder.resnet34_encoder import ResNet34Diar
  36. from funasr.models.encoder.rnn_encoder import RNNEncoder
  37. from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
  38. from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
  39. from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
  40. from funasr.models.encoder.transformer_encoder import TransformerEncoder
  41. from funasr.models.frontend.default import DefaultFrontend
  42. from funasr.models.frontend.default import MultiChannelFrontend
  43. from funasr.models.frontend.fused import FusedFrontends
  44. from funasr.models.frontend.s3prl import S3prlFrontend
  45. from funasr.models.frontend.wav_frontend import WavFrontend
  46. from funasr.models.frontend.windowing import SlidingWindow
  47. from funasr.models.joint_net.joint_network import JointNetwork
  48. from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
  49. from funasr.models.specaug.specaug import SpecAug
  50. from funasr.models.specaug.specaug import SpecAugLFR
  51. from funasr.modules.subsampling import Conv1dSubsampling
  52. from funasr.torch_utils.initialize import initialize
  53. from funasr.train.class_choices import ClassChoices
  54. frontend_choices = ClassChoices(
  55. name="frontend",
  56. classes=dict(
  57. default=DefaultFrontend,
  58. sliding_window=SlidingWindow,
  59. s3prl=S3prlFrontend,
  60. fused=FusedFrontends,
  61. wav_frontend=WavFrontend,
  62. multichannelfrontend=MultiChannelFrontend,
  63. ),
  64. default="default",
  65. )
  66. specaug_choices = ClassChoices(
  67. name="specaug",
  68. classes=dict(
  69. specaug=SpecAug,
  70. specaug_lfr=SpecAugLFR,
  71. ),
  72. default=None,
  73. optional=True,
  74. )
  75. normalize_choices = ClassChoices(
  76. "normalize",
  77. classes=dict(
  78. global_mvn=GlobalMVN,
  79. utterance_mvn=UtteranceMVN,
  80. ),
  81. default=None,
  82. optional=True,
  83. )
  84. model_choices = ClassChoices(
  85. "model",
  86. classes=dict(
  87. asr=ASRModel,
  88. uniasr=UniASR,
  89. paraformer=Paraformer,
  90. paraformer_online=ParaformerOnline,
  91. paraformer_bert=ParaformerBert,
  92. bicif_paraformer=BiCifParaformer,
  93. contextual_paraformer=ContextualParaformer,
  94. neatcontextual_paraformer=NeatContextualParaformer,
  95. mfcca=MFCCA,
  96. timestamp_prediction=TimestampPredictor,
  97. rnnt=TransducerModel,
  98. rnnt_unified=UnifiedTransducerModel,
  99. sa_asr=SAASRModel,
  100. bat=BATModel,
  101. ),
  102. default="asr",
  103. )
  104. encoder_choices = ClassChoices(
  105. "encoder",
  106. classes=dict(
  107. conformer=ConformerEncoder,
  108. transformer=TransformerEncoder,
  109. rnn=RNNEncoder,
  110. sanm=SANMEncoder,
  111. sanm_chunk_opt=SANMEncoderChunkOpt,
  112. data2vec_encoder=Data2VecEncoder,
  113. branchformer=BranchformerEncoder,
  114. e_branchformer=EBranchformerEncoder,
  115. mfcca_enc=MFCCAEncoder,
  116. chunk_conformer=ConformerChunkEncoder,
  117. ),
  118. default="rnn",
  119. )
  120. asr_encoder_choices = ClassChoices(
  121. "asr_encoder",
  122. classes=dict(
  123. conformer=ConformerEncoder,
  124. transformer=TransformerEncoder,
  125. rnn=RNNEncoder,
  126. sanm=SANMEncoder,
  127. sanm_chunk_opt=SANMEncoderChunkOpt,
  128. data2vec_encoder=Data2VecEncoder,
  129. mfcca_enc=MFCCAEncoder,
  130. ),
  131. default="rnn",
  132. )
  133. spk_encoder_choices = ClassChoices(
  134. "spk_encoder",
  135. classes=dict(
  136. resnet34_diar=ResNet34Diar,
  137. ),
  138. default="resnet34_diar",
  139. )
  140. encoder_choices2 = ClassChoices(
  141. "encoder2",
  142. classes=dict(
  143. conformer=ConformerEncoder,
  144. transformer=TransformerEncoder,
  145. rnn=RNNEncoder,
  146. sanm=SANMEncoder,
  147. sanm_chunk_opt=SANMEncoderChunkOpt,
  148. ),
  149. default="rnn",
  150. )
  151. decoder_choices = ClassChoices(
  152. "decoder",
  153. classes=dict(
  154. transformer=TransformerDecoder,
  155. lightweight_conv=LightweightConvolutionTransformerDecoder,
  156. lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
  157. dynamic_conv=DynamicConvolutionTransformerDecoder,
  158. dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
  159. rnn=RNNDecoder,
  160. fsmn_scama_opt=FsmnDecoderSCAMAOpt,
  161. paraformer_decoder_sanm=ParaformerSANMDecoder,
  162. paraformer_decoder_san=ParaformerDecoderSAN,
  163. contextual_paraformer_decoder=ContextualParaformerDecoder,
  164. sa_decoder=SAAsrTransformerDecoder,
  165. ),
  166. default="rnn",
  167. )
  168. decoder_choices2 = ClassChoices(
  169. "decoder2",
  170. classes=dict(
  171. transformer=TransformerDecoder,
  172. lightweight_conv=LightweightConvolutionTransformerDecoder,
  173. lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
  174. dynamic_conv=DynamicConvolutionTransformerDecoder,
  175. dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
  176. rnn=RNNDecoder,
  177. fsmn_scama_opt=FsmnDecoderSCAMAOpt,
  178. paraformer_decoder_sanm=ParaformerSANMDecoder,
  179. ),
  180. type_check=AbsDecoder,
  181. default="rnn",
  182. )
  183. predictor_choices = ClassChoices(
  184. name="predictor",
  185. classes=dict(
  186. cif_predictor=CifPredictor,
  187. ctc_predictor=None,
  188. cif_predictor_v2=CifPredictorV2,
  189. cif_predictor_v3=CifPredictorV3,
  190. bat_predictor=BATPredictor,
  191. ),
  192. default="cif_predictor",
  193. optional=True,
  194. )
  195. predictor_choices2 = ClassChoices(
  196. name="predictor2",
  197. classes=dict(
  198. cif_predictor=CifPredictor,
  199. ctc_predictor=None,
  200. cif_predictor_v2=CifPredictorV2,
  201. ),
  202. default="cif_predictor",
  203. optional=True,
  204. )
  205. stride_conv_choices = ClassChoices(
  206. name="stride_conv",
  207. classes=dict(
  208. stride_conv1d=Conv1dSubsampling
  209. ),
  210. default="stride_conv1d",
  211. optional=True,
  212. )
  213. rnnt_decoder_choices = ClassChoices(
  214. name="rnnt_decoder",
  215. classes=dict(
  216. rnnt=RNNTDecoder,
  217. ),
  218. default="rnnt",
  219. optional=True,
  220. )
  221. joint_network_choices = ClassChoices(
  222. name="joint_network",
  223. classes=dict(
  224. joint_network=JointNetwork,
  225. ),
  226. default="joint_network",
  227. optional=True,
  228. )
  229. class_choices_list = [
  230. # --frontend and --frontend_conf
  231. frontend_choices,
  232. # --specaug and --specaug_conf
  233. specaug_choices,
  234. # --normalize and --normalize_conf
  235. normalize_choices,
  236. # --model and --model_conf
  237. model_choices,
  238. # --encoder and --encoder_conf
  239. encoder_choices,
  240. # --decoder and --decoder_conf
  241. decoder_choices,
  242. # --predictor and --predictor_conf
  243. predictor_choices,
  244. # --encoder2 and --encoder2_conf
  245. encoder_choices2,
  246. # --decoder2 and --decoder2_conf
  247. decoder_choices2,
  248. # --predictor2 and --predictor2_conf
  249. predictor_choices2,
  250. # --stride_conv and --stride_conv_conf
  251. stride_conv_choices,
  252. # --rnnt_decoder and --rnnt_decoder_conf
  253. rnnt_decoder_choices,
  254. # --joint_network and --joint_network_conf
  255. joint_network_choices,
  256. # --asr_encoder and --asr_encoder_conf
  257. asr_encoder_choices,
  258. # --spk_encoder and --spk_encoder_conf
  259. spk_encoder_choices,
  260. ]
  261. def build_asr_model(args):
  262. # token_list
  263. if isinstance(args.token_list, str):
  264. with open(args.token_list, encoding="utf-8") as f:
  265. token_list = [line.rstrip() for line in f]
  266. args.token_list = list(token_list)
  267. vocab_size = len(token_list)
  268. logging.info(f"Vocabulary size: {vocab_size}")
  269. elif isinstance(args.token_list, (tuple, list)):
  270. token_list = list(args.token_list)
  271. vocab_size = len(token_list)
  272. logging.info(f"Vocabulary size: {vocab_size}")
  273. else:
  274. token_list = None
  275. vocab_size = None
  276. # frontend
  277. if hasattr(args, "input_size") and args.input_size is None:
  278. frontend_class = frontend_choices.get_class(args.frontend)
  279. if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
  280. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  281. else:
  282. frontend = frontend_class(**args.frontend_conf)
  283. input_size = frontend.output_size()
  284. else:
  285. args.frontend = None
  286. args.frontend_conf = {}
  287. frontend = None
  288. input_size = args.input_size if hasattr(args, "input_size") else None
  289. # data augmentation for spectrogram
  290. if args.specaug is not None:
  291. specaug_class = specaug_choices.get_class(args.specaug)
  292. specaug = specaug_class(**args.specaug_conf)
  293. else:
  294. specaug = None
  295. # normalization layer
  296. if args.normalize is not None:
  297. normalize_class = normalize_choices.get_class(args.normalize)
  298. if args.model == "mfcca":
  299. normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
  300. else:
  301. normalize = normalize_class(**args.normalize_conf)
  302. else:
  303. normalize = None
  304. # encoder
  305. encoder_class = encoder_choices.get_class(args.encoder)
  306. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  307. # decoder
  308. if hasattr(args, "decoder") and args.decoder is not None:
  309. decoder_class = decoder_choices.get_class(args.decoder)
  310. decoder = decoder_class(
  311. vocab_size=vocab_size,
  312. encoder_output_size=encoder.output_size(),
  313. **args.decoder_conf,
  314. )
  315. else:
  316. decoder = None
  317. # ctc
  318. ctc = CTC(
  319. odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
  320. )
  321. if args.model in ["asr", "mfcca"]:
  322. model_class = model_choices.get_class(args.model)
  323. model = model_class(
  324. vocab_size=vocab_size,
  325. frontend=frontend,
  326. specaug=specaug,
  327. normalize=normalize,
  328. encoder=encoder,
  329. decoder=decoder,
  330. ctc=ctc,
  331. token_list=token_list,
  332. **args.model_conf,
  333. )
  334. elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
  335. "contextual_paraformer", "neatcontextual_paraformer"]:
  336. # predictor
  337. predictor_class = predictor_choices.get_class(args.predictor)
  338. predictor = predictor_class(**args.predictor_conf)
  339. model_class = model_choices.get_class(args.model)
  340. model = model_class(
  341. vocab_size=vocab_size,
  342. frontend=frontend,
  343. specaug=specaug,
  344. normalize=normalize,
  345. encoder=encoder,
  346. decoder=decoder,
  347. ctc=ctc,
  348. token_list=token_list,
  349. predictor=predictor,
  350. **args.model_conf,
  351. )
  352. elif args.model == "uniasr":
  353. # stride_conv
  354. stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
  355. stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
  356. odim=input_size + encoder.output_size())
  357. stride_conv_output_size = stride_conv.output_size()
  358. # encoder2
  359. encoder_class2 = encoder_choices2.get_class(args.encoder2)
  360. encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
  361. # decoder2
  362. decoder_class2 = decoder_choices2.get_class(args.decoder2)
  363. decoder2 = decoder_class2(
  364. vocab_size=vocab_size,
  365. encoder_output_size=encoder2.output_size(),
  366. **args.decoder2_conf,
  367. )
  368. # ctc2
  369. ctc2 = CTC(
  370. odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
  371. )
  372. # predictor
  373. predictor_class = predictor_choices.get_class(args.predictor)
  374. predictor = predictor_class(**args.predictor_conf)
  375. # predictor2
  376. predictor_class = predictor_choices2.get_class(args.predictor2)
  377. predictor2 = predictor_class(**args.predictor2_conf)
  378. model_class = model_choices.get_class(args.model)
  379. model = model_class(
  380. vocab_size=vocab_size,
  381. frontend=frontend,
  382. specaug=specaug,
  383. normalize=normalize,
  384. encoder=encoder,
  385. decoder=decoder,
  386. ctc=ctc,
  387. token_list=token_list,
  388. predictor=predictor,
  389. ctc2=ctc2,
  390. encoder2=encoder2,
  391. decoder2=decoder2,
  392. predictor2=predictor2,
  393. stride_conv=stride_conv,
  394. **args.model_conf,
  395. )
  396. elif args.model == "timestamp_prediction":
  397. # predictor
  398. predictor_class = predictor_choices.get_class(args.predictor)
  399. predictor = predictor_class(**args.predictor_conf)
  400. model_class = model_choices.get_class(args.model)
  401. model = model_class(
  402. frontend=frontend,
  403. encoder=encoder,
  404. predictor=predictor,
  405. token_list=token_list,
  406. **args.model_conf,
  407. )
  408. elif args.model == "rnnt" or args.model == "rnnt_unified":
  409. # 5. Decoder
  410. encoder_output_size = encoder.output_size()
  411. rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
  412. decoder = rnnt_decoder_class(
  413. vocab_size,
  414. **args.rnnt_decoder_conf,
  415. )
  416. decoder_output_size = decoder.output_size
  417. if getattr(args, "decoder", None) is not None:
  418. att_decoder_class = decoder_choices.get_class(args.decoder)
  419. att_decoder = att_decoder_class(
  420. vocab_size=vocab_size,
  421. encoder_output_size=encoder_output_size,
  422. **args.decoder_conf,
  423. )
  424. else:
  425. att_decoder = None
  426. # 6. Joint Network
  427. joint_network = JointNetwork(
  428. vocab_size,
  429. encoder_output_size,
  430. decoder_output_size,
  431. **args.joint_network_conf,
  432. )
  433. model_class = model_choices.get_class(args.model)
  434. # 7. Build model
  435. model = model_class(
  436. vocab_size=vocab_size,
  437. token_list=token_list,
  438. frontend=frontend,
  439. specaug=specaug,
  440. normalize=normalize,
  441. encoder=encoder,
  442. decoder=decoder,
  443. att_decoder=att_decoder,
  444. joint_network=joint_network,
  445. **args.model_conf,
  446. )
  447. elif args.model == "bat":
  448. # 5. Decoder
  449. encoder_output_size = encoder.output_size()
  450. rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
  451. decoder = rnnt_decoder_class(
  452. vocab_size,
  453. **args.rnnt_decoder_conf,
  454. )
  455. decoder_output_size = decoder.output_size
  456. if getattr(args, "decoder", None) is not None:
  457. att_decoder_class = decoder_choices.get_class(args.decoder)
  458. att_decoder = att_decoder_class(
  459. vocab_size=vocab_size,
  460. encoder_output_size=encoder_output_size,
  461. **args.decoder_conf,
  462. )
  463. else:
  464. att_decoder = None
  465. # 6. Joint Network
  466. joint_network = JointNetwork(
  467. vocab_size,
  468. encoder_output_size,
  469. decoder_output_size,
  470. **args.joint_network_conf,
  471. )
  472. predictor_class = predictor_choices.get_class(args.predictor)
  473. predictor = predictor_class(**args.predictor_conf)
  474. model_class = model_choices.get_class(args.model)
  475. # 7. Build model
  476. model = model_class(
  477. vocab_size=vocab_size,
  478. token_list=token_list,
  479. frontend=frontend,
  480. specaug=specaug,
  481. normalize=normalize,
  482. encoder=encoder,
  483. decoder=decoder,
  484. att_decoder=att_decoder,
  485. joint_network=joint_network,
  486. predictor=predictor,
  487. **args.model_conf,
  488. )
  489. elif args.model == "sa_asr":
  490. asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
  491. asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
  492. spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
  493. spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
  494. decoder = decoder_class(
  495. vocab_size=vocab_size,
  496. encoder_output_size=asr_encoder.output_size(),
  497. **args.decoder_conf,
  498. )
  499. ctc = CTC(
  500. odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
  501. )
  502. model_class = model_choices.get_class(args.model)
  503. model = model_class(
  504. vocab_size=vocab_size,
  505. frontend=frontend,
  506. specaug=specaug,
  507. normalize=normalize,
  508. asr_encoder=asr_encoder,
  509. spk_encoder=spk_encoder,
  510. decoder=decoder,
  511. ctc=ctc,
  512. token_list=token_list,
  513. **args.model_conf,
  514. )
  515. else:
  516. raise NotImplementedError("Not supported model: {}".format(args.model))
  517. # initialize
  518. if args.init is not None:
  519. initialize(model, args.init)
  520. return model