# asr.py
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr_paraformer import (
    Paraformer,
    ParaformerOnline,
    ParaformerBert,
    BiCifParaformer,
    ContextualParaformer,
)
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
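
# Each ClassChoices registry below maps a CLI-friendly name to a concrete
# class; AbsTask later expands every registry into a pair of command-line
# options: --<name> to select the class and --<name>_conf for its kwargs.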
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        multichannelfrontend=MultiChannelFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        asr=ASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_online=ParaformerOnline,
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        neatcontextual_paraformer=NeatContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
        rnnt=TransducerModel,
        rnnt_unified=UnifiedTransducerModel,
    ),
    type_check=FunASRModel,
    default="asr",
)
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
        chunk_conformer=ConformerChunkEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
encoder_choices2 = ClassChoices(
    "encoder2",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
decoder_choices2 = ClassChoices(
    "decoder2",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
rnnt_decoder_choices = ClassChoices(
    "rnnt_decoder",
    classes=dict(
        rnnt=RNNTDecoder,
    ),
    type_check=RNNTDecoder,
    default="rnnt",
)
joint_network_choices = ClassChoices(
    name="joint_network",
    classes=dict(
        joint_network=JointNetwork,
    ),
    default="joint_network",
    optional=True,
)
predictor_choices = ClassChoices(
    name="predictor",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
        cif_predictor_v3=CifPredictorV3,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
predictor_choices2 = ClassChoices(
    name="predictor2",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
stride_conv_choices = ClassChoices(
    name="stride_conv",
    classes=dict(
        stride_conv1d=Conv1dSubsampling,
    ),
    type_check=None,
    default="stride_conv1d",
    optional=True,
)
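
# A hypothetical lookup sketch (the values below are illustrative only):
#   encoder_class = encoder_choices.get_class("conformer")  # -> ConformerEncoder
#   encoder = encoder_class(input_size=80, **encoder_conf)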
class ASRTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
        # --rnnt_decoder and --rnnt_decoder_conf
        rnnt_decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer
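
    # add_task_arguments registers the task-level options (token list,
    # tokenization, preprocessing, augmentation) and then appends the
    # --<name>/--<name>_conf option pair for every registry in
    # class_choices_list.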
    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text file mapping int-ids to tokens",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text on <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="The seg_dict file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimensions of the feature",
        )
        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for the CTC class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to the data",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The level at which the text will be tokenized",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="The non_linguistic_symbols file path",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify the g2p method if --token_type=phn",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of the RIR scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the CMVN file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of the noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)
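
    # build_collate_fn batches (utterance-id, feature-dict) samples into
    # padded tensors; floats are padded with 0.0 and token ids with -1,
    # because id 0 is reserved for the CTC blank symbol.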
    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by the CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
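
    # build_preprocess_fn assembles the text/speech preprocessor; the
    # hasattr checks below keep configs written before these options
    # existed loading without error.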
    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols
                if hasattr(args, "non_linguistic_symbols")
                else None,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=args.split_with_space
                if hasattr(args, "split_with_space")
                else False,
                seg_dict_file=args.seg_dict_file
                if hasattr(args, "seg_dict_file")
                else None,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval
    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval
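
    # build_model assembles the pipeline in order: frontend -> specaug ->
    # normalize -> preencoder -> encoder -> postencoder -> decoder, plus a
    # CTC head, then wraps everything in the class selected by --model.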
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model
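

# ASRTaskUniASR builds the two-pass UniASR model: a strided convolution
# fuses the first-pass encoder output with the input features and feeds a
# second encoder/decoder/predictor stack, so most components come in pairs.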
class ASRTaskUniASR(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()
        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
        stride_conv = stride_conv_class(
            **args.stride_conv_conf,
            idim=input_size + encoder_output_size,
            odim=input_size + encoder_output_size,
        )
        stride_conv_output_size = stride_conv.output_size()

        # 6. Encoder2
        encoder_class2 = encoder_choices2.get_class(args.encoder2)
        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)

        # 7. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size2 = encoder2.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 8. Decoder & Decoder2
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder_class2 = decoder_choices2.get_class(args.decoder2)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )
        decoder2 = decoder_class2(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size2,
            **args.decoder2_conf,
        )

        # 9. CTC & CTC2
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )
        ctc2 = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size2, **args.ctc_conf
        )

        # 10. Predictor & Predictor2
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        predictor_class2 = predictor_choices2.get_class(args.predictor2)
        predictor2 = predictor_class2(**args.predictor2_conf)

        # 11. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            ctc2=ctc2,
            encoder2=encoder2,
            decoder2=decoder2,
            predictor2=predictor2,
            stride_conv=stride_conv,
            **args.model_conf,
        )

        # 12. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file, overriding the one in the config.
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                # TensorFlow checkpoint: convert once and cache the state dict as a .pb file
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        return model, args
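
    # A hypothetical usage sketch (the paths below are illustrative only):
    #   model, args = ASRTaskUniASR.build_model_from_file(
    #       config_file="exp/uniasr/config.yaml",
    #       model_file="exp/uniasr/model.pb",
    #       device="cuda",
    #   )
    #   model.eval()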
    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start converting the tf model to a torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # encoder2
        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor2
        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder2
        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # stride_conv
        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
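

# ASRTaskParaformer builds the non-autoregressive Paraformer family, which
# adds a CIF-style predictor on top of the standard encoder/decoder pipeline.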
class ASRTaskParaformer(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # # Add variable objects configurations
    # class_choices_list = [
    #     # --frontend and --frontend_conf
    #     frontend_choices,
    #     # --specaug and --specaug_conf
    #     specaug_choices,
    #     # --normalize and --normalize_conf
    #     normalize_choices,
    #     # --model and --model_conf
    #     model_choices,
    #     # --preencoder and --preencoder_conf
    #     preencoder_choices,
    #     # --encoder and --encoder_conf
    #     encoder_choices,
    #     # --postencoder and --postencoder_conf
    #     postencoder_choices,
    #     # --decoder and --decoder_conf
    #     decoder_choices,
    #     # --predictor and --predictor_conf
    #     predictor_choices,
    # ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 10. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            **args.model_conf,
        )

        # 11. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file, overriding the one in the config.
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                # TensorFlow checkpoint: convert once and cache the state dict as a .pb file
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        model.to(device)
        return model, args
    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start converting the tf model to a torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
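

# ASRTaskMFCCA builds the multi-channel MFCCA model; it reuses the standard
# pipeline but passes the cmvn stats file to the normalization layer.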
class ASRTaskMFCCA(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )

        # 7. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
        )

        # 8. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        rnnt_decoder = None
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            rnnt_decoder=rnnt_decoder,
            token_list=token_list,
            **args.model_conf,
        )

        # 9. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model
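

# ASRTaskAligner reuses the Paraformer task to train and run the timestamp
# predictor; only the frontend, encoder, and CIF predictor are built.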
class ASRTaskAligner(ASRTaskParaformer):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 4. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval
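

# ASRTransducerTask builds RNN-T style models: an encoder, a label-prediction
# network (rnnt_decoder), and a joint network that fuses the two; an optional
# attention decoder can be attached for the unified model.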
class ASRTransducerTask(ASRTask):
    """ASR Transducer Task definition."""

    num_optimizers: int = 1

    class_choices_list = [
        model_choices,
        frontend_choices,
        specaug_choices,
        normalize_choices,
        encoder_choices,
        rnnt_decoder_choices,
        joint_network_choices,
    ]

    trainer = Trainer
    @classmethod
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Build the ASR Transducer model from the task arguments.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.

        Return:
            model: ASR Transducer model.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        if getattr(args, "encoder", None) is not None:
            encoder_class = encoder_choices.get_class(args.encoder)
            encoder = encoder_class(input_size, **args.encoder_conf)
        else:
            raise ValueError("--encoder must be specified for ASRTransducerTask")
        encoder_output_size = encoder.output_size()

        # 5. Decoder
        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
        decoder = rnnt_decoder_class(
            vocab_size,
            **args.rnnt_decoder_conf,
        )
        decoder_output_size = decoder.output_size
        if getattr(args, "decoder", None) is not None:
            att_decoder_class = decoder_choices.get_class(args.decoder)
            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )
        else:
            att_decoder = None

        # 6. Joint Network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )

        # 7. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("rnnt_unified")
        model = model_class(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
            decoder=decoder,
            att_decoder=att_decoder,
            joint_network=joint_network,
            **args.model_conf,
        )

        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported.",
                "Initialization part will be reworked in the near future.",
            )
        # assert check_return_type(model)
        return model