import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
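
# Each ``ClassChoices`` below maps a command-line name to an implementation
# class, and later registers ``--<name>`` / ``--<name>_conf`` options on the
# argparse parser. A minimal sketch of the lookup (values are illustrative):
#
#     cls_ = encoder_choices.get_class("conformer")    # -> ConformerEncoder
#     encoder = cls_(input_size=80, **{"num_blocks": 12})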
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        multichannelfrontend=MultiChannelFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        asr=ESPnetASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    default="asr",
)
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
        chunk_conformer=ConformerChunkEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
encoder_choices2 = ClassChoices(
    "encoder2",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
decoder_choices2 = ClassChoices(
    "decoder2",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
rnnt_decoder_choices = ClassChoices(
    "rnnt_decoder",
    classes=dict(
        rnnt=RNNTDecoder,
    ),
    type_check=RNNTDecoder,
    default="rnnt",
)
predictor_choices = ClassChoices(
    name="predictor",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
        cif_predictor_v3=CifPredictorV3,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
predictor_choices2 = ClassChoices(
    name="predictor2",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
stride_conv_choices = ClassChoices(
    name="stride_conv",
    classes=dict(
        stride_conv1d=Conv1dSubsampling,
    ),
    type_check=None,
    default="stride_conv1d",
    optional=True,
)
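
# Each task class below follows the AbsTask pattern: ``class_choices_list``
# drives the argparse registration in ``add_task_arguments``, and
# ``build_model`` wires the selected components together.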
class ASRTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text file mapping int-ids to tokens",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text on <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimensions of the feature",
        )
        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for the CTC class.",
        )
        group.add_argument(
            "--joint_net_conf",
            action=NestedDictAction,
            default=None,
            help="The keyword arguments for the joint network class.",
        )
        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to the data",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized at the specified token level",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify the g2p method if --token_type=phn",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of the RIR scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the CMVN statistics file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of the noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )
        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)
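        # Example invocation exercising the options registered above (the
        # entry-point module and values are illustrative, not prescriptive):
        #
        #   python -m funasr.bin.asr_train \
        #       --token_list tokens.txt --token_type char \
        #       --encoder conformer --encoder_conf '{"num_blocks": 12}' \
        #       --decoder transformer --model paraformer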

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by the CTC-blank symbol
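        # Padding sketch: CommonCollateFn stacks variable-length items in a
        # mini-batch, padding features with 0.0 and token ids with -1 (since 0
        # is the CTC blank), and also returns per-item length tensors.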
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                # NOTE(kamo): Check attribute existence for backward compatibility
                split_with_space=getattr(args, "split_with_space", False),
                seg_dict_file=getattr(args, "seg_dict_file", None),
                rir_scp=getattr(args, "rir_scp", None),
                rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
                noise_scp=getattr(args, "noise_scp", None),
                noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
                noise_db_range=getattr(args, "noise_db_range", "13_15"),
                # Guard on speech_volume_normalize itself (the original checked
                # "rir_scp" here, which looks like a copy-paste slip).
                speech_volume_normalize=getattr(args, "speech_volume_normalize", None),
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
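        # Minimal sketch of driving the pipeline above programmatically (the
        # file names are illustrative):
        #
        #   args = argparse.Namespace(**yaml.safe_load(open("config.yaml")))
        #   model = ASRTask.build_model(args)
        #   # data flow: frontend -> specaug -> normalize -> (pre)encoder ->
        #   #            (post)encoder -> decoder / CTC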
        return model


class ASRTaskUniASR(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()
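        # The stride conv bridges UniASR's two passes: it subsamples the
        # first-pass features concatenated with the first encoder's output,
        # hence idim = odim = input_size + encoder_output_size.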
        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
        stride_conv = stride_conv_class(
            **args.stride_conv_conf,
            idim=input_size + encoder_output_size,
            odim=input_size + encoder_output_size,
        )
        stride_conv_output_size = stride_conv.output_size()

        # 6. Encoder2
        encoder_class2 = encoder_choices2.get_class(args.encoder2)
        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)

        # 7. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size2 = encoder2.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 8. Decoder & Decoder2
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder_class2 = decoder_choices2.get_class(args.decoder2)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )
        decoder2 = decoder_class2(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size2,
            **args.decoder2_conf,
        )

        # 9. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )
        ctc2 = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size2, **args.ctc_conf
        )

        # 10. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        predictor_class2 = predictor_choices2.get_class(args.predictor2)
        predictor2 = predictor_class2(**args.predictor2_conf)

        # 11. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            ctc2=ctc2,
            encoder2=encoder2,
            decoder2=decoder2,
            predictor2=predictor2,
            stride_conv=stride_conv,
            **args.model_conf,
        )

        # 12. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build a model from files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved during training.
            model_file: The model file saved during training.
            cmvn_file: An optional CMVN file that overrides the one in the config.
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
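            # TensorFlow checkpoints ("model.ckpt-*") and ".bin" files are
            # converted to a PyTorch state dict once and cached next to the
            # original file as "*.pb", so later loads skip the conversion.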
            if "model.ckpt-" in model_name or ".bin" in model_name:
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        return model, args

    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start converting tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # encoder2
        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor2
        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder2
        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # stride_conv
        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update


class ASRTaskParaformer(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
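        # The CIF predictor integrates per-frame weights to estimate the number
        # of output tokens and emits one acoustic embedding per token, which is
        # what lets Paraformer decode non-autoregressively.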

        # 10. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            **args.model_conf,
        )

        # 11. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build a model from files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved during training.
            model_file: The model file saved during training.
            cmvn_file: An optional CMVN file that overrides the one in the config.
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        model.to(device)
        return model, args

    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start converting tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update


class ASRTaskMFCCA(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )

        # 7. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
        )

        # 8. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        rnnt_decoder = None
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            rnnt_decoder=rnnt_decoder,
            token_list=token_list,
            **args.model_conf,
        )

        # 9. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model


class ASRTaskAligner(ASRTaskParaformer):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 4. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)
        assert check_return_type(model)
        return model

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval


class ASRTransducerTask(AbsTask):
    """ASR Transducer Task definition."""

    num_optimizers: int = 1

    class_choices_list = [
        frontend_choices,
        specaug_choices,
        normalize_choices,
        encoder_choices,
        rnnt_decoder_choices,
    ]

    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Add Transducer task arguments.

        Args:
            cls: ASRTransducerTask object.
            parser: Transducer arguments parser.
        """
        group = parser.add_argument_group(description="Task related.")
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="Integer-string mapper for tokens.",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text on <space>",
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of dimensions for input features.",
        )
        group.add_argument(
            "--init",
            type=str_or_none,
            default=None,
            help="Type of model initialization to use.",
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(TransducerModel),
            help="The keyword arguments for the model class.",
        )
        # group.add_argument(
        #     "--encoder_conf",
        #     action=NestedDictAction,
        #     default={},
        #     help="The keyword arguments for the encoder class.",
        # )
        group.add_argument(
            "--joint_network_conf",
            action=NestedDictAction,
            default={},
            help="The keyword arguments for the joint network class.",
        )
        group = parser.add_argument_group(description="Preprocess related.")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to input data.",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The type of tokens to use during tokenization.",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The path of the sentencepiece model.",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="The 'non_linguistic_symbols' file path.",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Text cleaner to use.",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="g2p method to use if --token_type=phn.",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Normalization value for maximum amplitude scaling.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The RIR SCP file path.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The path of the noise SCP file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )
        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --decoder and --decoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Build collate function.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.

        Return:
            : Callable collate function.
        """
        assert check_argument_types()
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Build pre-processing function.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.

        Return:
            : Callable pre-processing function.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                # Check attribute existence for backward compatibility
                split_with_space=getattr(args, "split_with_space", False),
                rir_scp=getattr(args, "rir_scp", None),
                rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
                noise_scp=getattr(args, "noise_scp", None),
                noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
                noise_db_range=getattr(args, "noise_db_range", "13_15"),
                # Guard on speech_volume_normalize itself (the original checked
                # "rir_scp" here, which looks like a copy-paste slip).
                speech_volume_normalize=getattr(args, "speech_volume_normalize", None),
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Required data depending on task mode.

        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.

        Return:
            retval: Required task data.
        """
        if not inference:
            retval = ("speech", "text")
        else:
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Optional data depending on task mode.

        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.

        Return:
            retval: Optional task data.
        """
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Build the ASR Transducer model.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.

        Return:
            model: ASR Transducer model.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from the data-loader
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        if getattr(args, "encoder", None) is not None:
            encoder_class = encoder_choices.get_class(args.encoder)
            encoder = encoder_class(input_size, **args.encoder_conf)
        else:
            # NOTE: ``Encoder`` is not imported in this module; this fallback
            # assumes a transducer-specific Encoder class is available in scope.
            encoder = Encoder(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()

        # 5. Decoder
        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
        decoder = rnnt_decoder_class(
            vocab_size,
            **args.rnnt_decoder_conf,
        )
        decoder_output_size = decoder.output_size
        if getattr(args, "decoder", None) is not None:
            # The guard above checks ``args.decoder``, so look the class up by
            # the same attribute (the original read ``args.att_decoder``, which
            # this task never registers).
            att_decoder_class = decoder_choices.get_class(args.decoder)
            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )
        else:
            att_decoder = None

        # 6. Joint Network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )
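        # The joint network fuses each encoder frame with each prediction-network
        # state into per-(frame, token) logits over the vocabulary, i.e. the
        # RNN-T output lattice that the transducer loss marginalizes over.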

        # 7. Build model
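        # If the encoder was configured for unified (streaming + full-context)
        # training, the unified model variant is built so both modes share
        # weights; otherwise a plain transducer is built.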
        if hasattr(encoder, "unified_model_training") and encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        else:
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )

        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported. "
                "The initialization part will be reworked in the near future."
            )
        # assert check_return_type(model)
        return model