# build_asr_model.py
  1. import logging
  2. from funasr.layers.global_mvn import GlobalMVN
  3. from funasr.layers.utterance_mvn import UtteranceMVN
  4. from funasr.models.ctc import CTC
  5. from funasr.models.decoder.abs_decoder import AbsDecoder
  6. from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
  7. from funasr.models.decoder.rnn_decoder import RNNDecoder
  8. from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
  9. from funasr.models.decoder.transformer_decoder import (
  10. DynamicConvolution2DTransformerDecoder, # noqa: H301
  11. )
  12. from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
  13. from funasr.models.decoder.transformer_decoder import (
  14. LightweightConvolution2DTransformerDecoder, # noqa: H301
  15. )
  16. from funasr.models.decoder.transformer_decoder import (
  17. LightweightConvolutionTransformerDecoder, # noqa: H301
  18. )
  19. from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
  20. from funasr.models.decoder.transformer_decoder import TransformerDecoder
  21. from funasr.models.decoder.rnnt_decoder import RNNTDecoder
  22. from funasr.models.joint_net.joint_network import JointNetwork
  23. from funasr.models.e2e_asr import ASRModel
  24. from funasr.models.e2e_asr_mfcca import MFCCA
  25. from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
  26. from funasr.models.e2e_tp import TimestampPredictor
  27. from funasr.models.e2e_uni_asr import UniASR
  28. from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
  29. from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
  30. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  31. from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
  32. from funasr.models.encoder.rnn_encoder import RNNEncoder
  33. from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
  34. from funasr.models.encoder.transformer_encoder import TransformerEncoder
  35. from funasr.models.frontend.default import DefaultFrontend
  36. from funasr.models.frontend.default import MultiChannelFrontend
  37. from funasr.models.frontend.fused import FusedFrontends
  38. from funasr.models.frontend.s3prl import S3prlFrontend
  39. from funasr.models.frontend.wav_frontend import WavFrontend
  40. from funasr.models.frontend.windowing import SlidingWindow
  41. from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
  42. from funasr.models.specaug.specaug import SpecAug
  43. from funasr.models.specaug.specaug import SpecAugLFR
  44. from funasr.modules.subsampling import Conv1dSubsampling
  45. from funasr.torch_utils.initialize import initialize
  46. from funasr.train.class_choices import ClassChoices
  47. frontend_choices = ClassChoices(
  48. name="frontend",
  49. classes=dict(
  50. default=DefaultFrontend,
  51. sliding_window=SlidingWindow,
  52. s3prl=S3prlFrontend,
  53. fused=FusedFrontends,
  54. wav_frontend=WavFrontend,
  55. multichannelfrontend=MultiChannelFrontend,
  56. ),
  57. default="default",
  58. )
  59. specaug_choices = ClassChoices(
  60. name="specaug",
  61. classes=dict(
  62. specaug=SpecAug,
  63. specaug_lfr=SpecAugLFR,
  64. ),
  65. default=None,
  66. optional=True,
  67. )
  68. normalize_choices = ClassChoices(
  69. "normalize",
  70. classes=dict(
  71. global_mvn=GlobalMVN,
  72. utterance_mvn=UtteranceMVN,
  73. ),
  74. default=None,
  75. optional=True,
  76. )
  77. model_choices = ClassChoices(
  78. "model",
  79. classes=dict(
  80. asr=ASRModel,
  81. uniasr=UniASR,
  82. paraformer=Paraformer,
  83. paraformer_online=ParaformerOnline,
  84. paraformer_bert=ParaformerBert,
  85. bicif_paraformer=BiCifParaformer,
  86. contextual_paraformer=ContextualParaformer,
  87. mfcca=MFCCA,
  88. timestamp_prediction=TimestampPredictor,
  89. rnnt=TransducerModel,
  90. rnnt_unified=UnifiedTransducerModel,
  91. ),
  92. default="asr",
  93. )
  94. encoder_choices = ClassChoices(
  95. "encoder",
  96. classes=dict(
  97. conformer=ConformerEncoder,
  98. transformer=TransformerEncoder,
  99. rnn=RNNEncoder,
  100. sanm=SANMEncoder,
  101. sanm_chunk_opt=SANMEncoderChunkOpt,
  102. data2vec_encoder=Data2VecEncoder,
  103. mfcca_enc=MFCCAEncoder,
  104. chunk_conformer=ConformerChunkEncoder,
  105. ),
  106. default="rnn",
  107. )
  108. encoder_choices2 = ClassChoices(
  109. "encoder2",
  110. classes=dict(
  111. conformer=ConformerEncoder,
  112. transformer=TransformerEncoder,
  113. rnn=RNNEncoder,
  114. sanm=SANMEncoder,
  115. sanm_chunk_opt=SANMEncoderChunkOpt,
  116. ),
  117. default="rnn",
  118. )
  119. decoder_choices = ClassChoices(
  120. "decoder",
  121. classes=dict(
  122. transformer=TransformerDecoder,
  123. lightweight_conv=LightweightConvolutionTransformerDecoder,
  124. lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
  125. dynamic_conv=DynamicConvolutionTransformerDecoder,
  126. dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
  127. rnn=RNNDecoder,
  128. fsmn_scama_opt=FsmnDecoderSCAMAOpt,
  129. paraformer_decoder_sanm=ParaformerSANMDecoder,
  130. paraformer_decoder_san=ParaformerDecoderSAN,
  131. contextual_paraformer_decoder=ContextualParaformerDecoder,
  132. ),
  133. default="rnn",
  134. )
  135. decoder_choices2 = ClassChoices(
  136. "decoder2",
  137. classes=dict(
  138. transformer=TransformerDecoder,
  139. lightweight_conv=LightweightConvolutionTransformerDecoder,
  140. lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
  141. dynamic_conv=DynamicConvolutionTransformerDecoder,
  142. dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
  143. rnn=RNNDecoder,
  144. fsmn_scama_opt=FsmnDecoderSCAMAOpt,
  145. paraformer_decoder_sanm=ParaformerSANMDecoder,
  146. ),
  147. type_check=AbsDecoder,
  148. default="rnn",
  149. )
  150. predictor_choices = ClassChoices(
  151. name="predictor",
  152. classes=dict(
  153. cif_predictor=CifPredictor,
  154. ctc_predictor=None,
  155. cif_predictor_v2=CifPredictorV2,
  156. cif_predictor_v3=CifPredictorV3,
  157. ),
  158. default="cif_predictor",
  159. optional=True,
  160. )
  161. predictor_choices2 = ClassChoices(
  162. name="predictor2",
  163. classes=dict(
  164. cif_predictor=CifPredictor,
  165. ctc_predictor=None,
  166. cif_predictor_v2=CifPredictorV2,
  167. ),
  168. default="cif_predictor",
  169. optional=True,
  170. )
  171. stride_conv_choices = ClassChoices(
  172. name="stride_conv",
  173. classes=dict(
  174. stride_conv1d=Conv1dSubsampling
  175. ),
  176. default="stride_conv1d",
  177. optional=True,
  178. )
  179. rnnt_decoder_choices = ClassChoices(
  180. name="rnnt_decoder",
  181. classes=dict(
  182. rnnt=RNNTDecoder,
  183. ),
  184. default="rnnt",
  185. optional=True,
  186. )
  187. joint_network_choices = ClassChoices(
  188. name="joint_network",
  189. classes=dict(
  190. joint_network=JointNetwork,
  191. ),
  192. default="joint_network",
  193. optional=True,
  194. )
  195. class_choices_list = [
  196. # --frontend and --frontend_conf
  197. frontend_choices,
  198. # --specaug and --specaug_conf
  199. specaug_choices,
  200. # --normalize and --normalize_conf
  201. normalize_choices,
  202. # --model and --model_conf
  203. model_choices,
  204. # --encoder and --encoder_conf
  205. encoder_choices,
  206. # --decoder and --decoder_conf
  207. decoder_choices,
  208. # --predictor and --predictor_conf
  209. predictor_choices,
  210. # --encoder2 and --encoder2_conf
  211. encoder_choices2,
  212. # --decoder2 and --decoder2_conf
  213. decoder_choices2,
  214. # --predictor2 and --predictor2_conf
  215. predictor_choices2,
  216. # --stride_conv and --stride_conv_conf
  217. stride_conv_choices,
  218. # --rnnt_decoder and --rnnt_decoder_conf
  219. rnnt_decoder_choices,
  220. # --joint_network and --joint_network_conf
  221. joint_network_choices,
  222. ]
  223. def build_asr_model(args):
  224. # token_list
  225. if args.token_list is not None:
  226. with open(args.token_list) as f:
  227. token_list = [line.rstrip() for line in f]
  228. args.token_list = list(token_list)
  229. vocab_size = len(token_list)
  230. logging.info(f"Vocabulary size: {vocab_size}")
  231. else:
  232. vocab_size = None
  233. # frontend
  234. if args.input_size is None:
  235. frontend_class = frontend_choices.get_class(args.frontend)
  236. if args.frontend == 'wav_frontend':
  237. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  238. else:
  239. frontend = frontend_class(**args.frontend_conf)
  240. input_size = frontend.output_size()
  241. else:
  242. args.frontend = None
  243. args.frontend_conf = {}
  244. frontend = None
  245. input_size = args.input_size
  246. # data augmentation for spectrogram
  247. if args.specaug is not None:
  248. specaug_class = specaug_choices.get_class(args.specaug)
  249. specaug = specaug_class(**args.specaug_conf)
  250. else:
  251. specaug = None
  252. # normalization layer
  253. if args.normalize is not None:
  254. normalize_class = normalize_choices.get_class(args.normalize)
  255. normalize = normalize_class(**args.normalize_conf)
  256. else:
  257. normalize = None
  258. # encoder
  259. encoder_class = encoder_choices.get_class(args.encoder)
  260. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  261. # decoder
  262. decoder_class = decoder_choices.get_class(args.decoder)
  263. decoder = decoder_class(
  264. vocab_size=vocab_size,
  265. encoder_output_size=encoder.output_size(),
  266. **args.decoder_conf,
  267. )
  268. # ctc
  269. ctc = CTC(
  270. odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
  271. )
  272. if args.model in ["asr", "mfcca"]:
  273. model_class = model_choices.get_class(args.model)
  274. model = model_class(
  275. vocab_size=vocab_size,
  276. frontend=frontend,
  277. specaug=specaug,
  278. normalize=normalize,
  279. encoder=encoder,
  280. decoder=decoder,
  281. ctc=ctc,
  282. token_list=token_list,
  283. **args.model_conf,
  284. )
  285. elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
  286. # predictor
  287. predictor_class = predictor_choices.get_class(args.predictor)
  288. predictor = predictor_class(**args.predictor_conf)
  289. model_class = model_choices.get_class(args.model)
  290. model = model_class(
  291. vocab_size=vocab_size,
  292. frontend=frontend,
  293. specaug=specaug,
  294. normalize=normalize,
  295. encoder=encoder,
  296. decoder=decoder,
  297. ctc=ctc,
  298. token_list=token_list,
  299. predictor=predictor,
  300. **args.model_conf,
  301. )
  302. elif args.model == "uniasr":
  303. # stride_conv
  304. stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
  305. stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
  306. odim=input_size + encoder.output_size())
  307. stride_conv_output_size = stride_conv.output_size()
  308. # encoder2
  309. encoder_class2 = encoder_choices2.get_class(args.encoder2)
  310. encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
  311. # decoder2
  312. decoder_class2 = decoder_choices2.get_class(args.decoder2)
  313. decoder2 = decoder_class2(
  314. vocab_size=vocab_size,
  315. encoder_output_size=encoder2.output_size(),
  316. **args.decoder2_conf,
  317. )
  318. # ctc2
  319. ctc2 = CTC(
  320. odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
  321. )
  322. # predictor
  323. predictor_class = predictor_choices.get_class(args.predictor)
  324. predictor = predictor_class(**args.predictor_conf)
  325. # predictor2
  326. predictor_class = predictor_choices2.get_class(args.predictor2)
  327. predictor2 = predictor_class(**args.predictor2_conf)
  328. model_class = model_choices.get_class(args.model)
  329. model = model_class(
  330. vocab_size=vocab_size,
  331. frontend=frontend,
  332. specaug=specaug,
  333. normalize=normalize,
  334. encoder=encoder,
  335. decoder=decoder,
  336. ctc=ctc,
  337. token_list=token_list,
  338. predictor=predictor,
  339. ctc2=ctc2,
  340. encoder2=encoder2,
  341. decoder2=decoder2,
  342. predictor2=predictor2,
  343. stride_conv=stride_conv,
  344. **args.model_conf,
  345. )
  346. elif args.model == "timestamp_prediction":
  347. model_class = model_choices.get_class(args.model)
  348. model = model_class(
  349. frontend=frontend,
  350. encoder=encoder,
  351. token_list=token_list,
  352. **args.model_conf,
  353. )
  354. elif args.model == "rnnt" or args.model == "rnnt_unified":
  355. # 5. Decoder
  356. encoder_output_size = encoder.output_size()
  357. rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
  358. decoder = rnnt_decoder_class(
  359. vocab_size,
  360. **args.rnnt_decoder_conf,
  361. )
  362. decoder_output_size = decoder.output_size
  363. if getattr(args, "decoder", None) is not None:
  364. att_decoder_class = decoder_choices.get_class(args.decoder)
  365. att_decoder = att_decoder_class(
  366. vocab_size=vocab_size,
  367. encoder_output_size=encoder_output_size,
  368. **args.decoder_conf,
  369. )
  370. else:
  371. att_decoder = None
  372. # 6. Joint Network
  373. joint_network = JointNetwork(
  374. vocab_size,
  375. encoder_output_size,
  376. decoder_output_size,
  377. **args.joint_network_conf,
  378. )
  379. model_class = model_choices.get_class(args.model)
  380. # 7. Build model
  381. model = model_class(
  382. vocab_size=vocab_size,
  383. token_list=token_list,
  384. frontend=frontend,
  385. specaug=specaug,
  386. normalize=normalize,
  387. encoder=encoder,
  388. decoder=decoder,
  389. att_decoder=att_decoder,
  390. joint_network=joint_network,
  391. **args.model_conf,
  392. )
  393. else:
  394. raise NotImplementedError("Not supported model: {}".format(args.model))
  395. # initialize
  396. if args.init is not None:
  397. initialize(model, args.init)
  398. return model