asr.py
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none

frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        multichannelfrontend=MultiChannelFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        asr=ESPnetASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    default="asr",
)
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
encoder_choices2 = ClassChoices(
    "encoder2",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
decoder_choices2 = ClassChoices(
    "decoder2",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
predictor_choices = ClassChoices(
    name="predictor",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
        cif_predictor_v3=CifPredictorV3,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
predictor_choices2 = ClassChoices(
    name="predictor2",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
stride_conv_choices = ClassChoices(
    name="stride_conv",
    classes=dict(
        stride_conv1d=Conv1dSubsampling,
    ),
    type_check=None,
    default="stride_conv1d",
    optional=True,
)
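
# Each ClassChoices object above contributes a paired set of CLI options:
# --<name> selects an implementation by key and --<name>_conf passes keyword
# arguments to its constructor (registered in add_task_arguments below via
# class_choices.add_arguments). A minimal usage sketch, assuming a
# hypothetical training entry point that calls ASRTask.main():
#
#   python asr_train.py \
#       --encoder conformer \
#       --encoder_conf '{"num_blocks": 12}' \
#       --decoder transformer \
#       --decoder_conf '{"num_blocks": 6}'
#
# The entry-point name and the conf keys shown are illustrative; the valid
# keys are the constructor arguments of the selected class.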


class ASRTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimensions of the feature",
        )
        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for the CTC class.",
        )
        group.add_argument(
            "--joint_net_conf",
            action=NestedDictAction,
            default=None,
            help="The keyword arguments for the joint network class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized at the specified token level",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of the rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the cmvn file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of the noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
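
    # CommonCollateFn pads the variable-length arrays in a minibatch so they
    # can be stacked into tensors: float features are padded with 0.0 and
    # integer token ids with -1, because id 0 is reserved for the CTC blank.
    # Illustrative sketch (values hypothetical): token sequences of lengths
    # 3 and 5 are batched as
    #   [[5, 9, 2, -1, -1],
    #    [7, 1, 4,  8,  6]]
    # together with a lengths tensor [3, 5].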

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
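
    # The assembled model runs a batch through the pipeline
    #   speech -> frontend -> specaug (train only) -> normalize -> preencoder
    #   -> encoder -> postencoder,
    # after which the attention decoder and the CTC head both consume the
    # encoder output; stages built as None above are expected to be skipped
    # inside the model class.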


class ASRTaskUniASR(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()
        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
        stride_conv = stride_conv_class(
            **args.stride_conv_conf,
            idim=input_size + encoder_output_size,
            odim=input_size + encoder_output_size,
        )
        stride_conv_output_size = stride_conv.output_size()

        # 6. Encoder2
        encoder_class2 = encoder_choices2.get_class(args.encoder2)
        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)

        # 7. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size2 = encoder2.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 8. Decoder & Decoder2
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder_class2 = decoder_choices2.get_class(args.decoder2)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )
        decoder2 = decoder_class2(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size2,
            **args.decoder2_conf,
        )

        # 9. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )
        ctc2 = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size2, **args.ctc_conf
        )

        # 10. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        predictor_class2 = predictor_choices2.get_class(args.predictor2)
        predictor2 = predictor_class2(**args.predictor2_conf)

        # 11. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            ctc2=ctc2,
            encoder2=encoder2,
            decoder2=decoder2,
            predictor2=predictor2,
            stride_conv=stride_conv,
            **args.model_conf,
        )

        # 12. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            device: Device type, "cpu", "cuda", or "cuda:N".

        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                # TensorFlow checkpoints are converted once, then cached as a
                # torch state dict with a ".pb" suffix next to the original file.
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        return model, args
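
    # A minimal usage sketch for inference (paths are illustrative):
    #
    #   model, args = ASRTaskUniASR.build_model_from_file(
    #       config_file="exp/uniasr/config.yaml",
    #       model_file="exp/uniasr/model.pb",
    #       device="cpu",
    #   )
    #   model.eval()
    #
    # When config_file is omitted, "config.yaml" next to model_file is used.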

    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start convert tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()

        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # encoder2
        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor2
        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder2
        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # stride_conv
        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)

        return var_dict_torch_update


class ASRTaskParaformer(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 10. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            **args.model_conf,
        )

        # 11. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            device: Device type, "cpu", "cuda", or "cuda:N".

        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                # TensorFlow checkpoints are converted once, then cached as a
                # torch state dict with a ".pb" suffix next to the original file.
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        model.to(device)
        return model, args

    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start convert tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()

        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)

        return var_dict_torch_update


class ASRTaskMFCCA(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )

        # 7. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
        )

        # 8. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        rnnt_decoder = None
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            rnnt_decoder=rnnt_decoder,
            token_list=token_list,
            **args.model_conf,
        )

        # 9. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model


class ASRTaskAligner(ASRTaskParaformer):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 4. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval