# asr.py
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import (
    Paraformer,
    ParaformerOnline,
    ParaformerBert,
    BiCifParaformer,
    ContextualParaformer,
)
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none

frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        multichannelfrontend=MultiChannelFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        asr=ESPnetASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_online=ParaformerOnline,
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    default="asr",
)
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
        chunk_conformer=ConformerChunkEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
encoder_choices2 = ClassChoices(
    "encoder2",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
decoder_choices2 = ClassChoices(
    "decoder2",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
rnnt_decoder_choices = ClassChoices(
    "rnnt_decoder",
    classes=dict(
        rnnt=RNNTDecoder,
    ),
    type_check=RNNTDecoder,
    default="rnnt",
)
predictor_choices = ClassChoices(
    name="predictor",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
        cif_predictor_v3=CifPredictorV3,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
predictor_choices2 = ClassChoices(
    name="predictor2",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
stride_conv_choices = ClassChoices(
    name="stride_conv",
    classes=dict(
        stride_conv1d=Conv1dSubsampling,
    ),
    type_check=None,
    default="stride_conv1d",
    optional=True,
)
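
# A minimal sketch of how a ClassChoices entry surfaces on the command line:
# add_arguments() registers the paired flags --<name> and --<name>_conf, and
# get_class() maps the selected tag back to its class. (The flag value below
# is illustrative.)
#
#     parser = argparse.ArgumentParser()
#     encoder_choices.add_arguments(parser.add_argument_group("encoder"))
#     args = parser.parse_args(["--encoder", "conformer"])
#     assert encoder_choices.get_class(args.encoder) is ConformerEncoder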


class ASRTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")
        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text file mapping int-ids to tokens",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="The seg_dict file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input feature dimensions",
        )
        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for the CTC class.",
        )
        group.add_argument(
            "--joint_net_conf",
            action=NestedDictAction,
            default=None,
            help="The keyword arguments for the joint network class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to the data",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized at the specified token level",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="The non_linguistic_symbols file path",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify the g2p method if --token_type=phn",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of the RIR scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the CMVN file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of the noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by the CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
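
    # A hedged sketch of the returned collate behavior (data values below are
    # illustrative): variable-length arrays in a mini-batch are padded to the
    # longest item, float arrays with 0.0 and token-id arrays with -1, so that
    # integer 0 stays free for the CTC blank:
    #
    #     collate = ASRTask.build_collate_fn(args, train=True)
    #     utt_ids, batch = collate(
    #         [
    #             ("utt1", {"text": np.array([3, 5, 7])}),
    #             ("utt2", {"text": np.array([4])}),
    #         ]
    #     )
    #     # batch["text"] -> tensor([[3, 5, 7], [4, -1, -1]])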

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval
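
    # A hedged usage sketch (utterance id and text below are illustrative):
    # the returned callable maps one raw example to model-ready arrays, e.g.
    # tokenizing the "text" entry into token ids with the configured tokenizer:
    #
    #     preprocess = ASRTask.build_preprocess_fn(args, train=False)
    #     example = preprocess(
    #         "utt1",
    #         {"speech": np.zeros(16000, dtype=np.float32), "text": "hello"},
    #     )
    #     # example["text"] is now an np.ndarray of int token ids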

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
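
    # The args namespace passed to build_model() mirrors the training
    # config.yaml saved by this task. A minimal illustrative excerpt (keys
    # follow the arguments registered above; all values are hypothetical):
    #
    #     token_list: data/token_list/char/tokens.txt
    #     input_size: null
    #     frontend: wav_frontend
    #     frontend_conf: {fs: 16000, n_mels: 80}
    #     specaug: specaug_lfr
    #     normalize: null
    #     encoder: conformer
    #     encoder_conf: {output_size: 256, num_blocks: 12}
    #     decoder: transformer
    #     model: asr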


class ASRTaskUniASR(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()
        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
        stride_conv = stride_conv_class(
            **args.stride_conv_conf,
            idim=input_size + encoder_output_size,
            odim=input_size + encoder_output_size,
        )
        stride_conv_output_size = stride_conv.output_size()

        # 6. Encoder2
        encoder_class2 = encoder_choices2.get_class(args.encoder2)
        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)

        # 7. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size2 = encoder2.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 8. Decoder & Decoder2
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder_class2 = decoder_choices2.get_class(args.decoder2)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )
        decoder2 = decoder_class2(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size2,
            **args.decoder2_conf,
        )

        # 9. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )
        ctc2 = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size2, **args.ctc_conf
        )

        # 10. Predictor & Predictor2
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        predictor_class2 = predictor_choices2.get_class(args.predictor2)
        predictor2 = predictor_class2(**args.predictor2_conf)

        # 11. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            ctc2=ctc2,
            encoder2=encoder2,
            decoder2=decoder2,
            predictor2=predictor2,
            stride_conv=stride_conv,
            **args.model_conf,
        )

        # 12. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The CMVN file used by the frontend (optional).
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                # Cache the converted TF checkpoint next to the original file.
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
            if model_name_pth is not None and not os.path.exists(model_name_pth):
                torch.save(model_dict, model_name_pth)
                logging.info("model_file is saved to pth: {}".format(model_name_pth))
        return model, args
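
    # A hedged usage sketch for inference (all paths are hypothetical):
    #
    #     model, args = ASRTaskUniASR.build_model_from_file(
    #         config_file="exp/uniasr/config.yaml",
    #         model_file="exp/uniasr/model.pb",
    #         device="cpu",
    #     )
    #     model.eval()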

    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start converting tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # encoder2
        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor2
        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder2
        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # stride_conv
        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)

        return var_dict_torch_update


class ASRTaskParaformer(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 10. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            **args.model_conf,
        )

        # 11. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The CMVN file used by the frontend (optional).
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                # Cache the converted TF checkpoint next to the original file.
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
            if model_name_pth is not None and not os.path.exists(model_name_pth):
                torch.save(model_dict, model_name_pth)
                logging.info("model_file is saved to pth: {}".format(model_name_pth))
        model.to(device)
        return model, args

    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start converting tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)

        return var_dict_torch_update


class ASRTaskMFCCA(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )

        # 7. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
        )

        # 8. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        rnnt_decoder = None
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            rnnt_decoder=rnnt_decoder,
            token_list=token_list,
            **args.model_conf,
        )

        # 9. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model


class ASRTaskAligner(ASRTaskParaformer):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify the train() or eval() procedures, change the Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 4. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval


class ASRTransducerTask(AbsTask):
    """ASR Transducer Task definition."""

    num_optimizers: int = 1

    class_choices_list = [
        frontend_choices,
        specaug_choices,
        normalize_choices,
        encoder_choices,
        rnnt_decoder_choices,
    ]

    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Add Transducer task arguments.

        Args:
            cls: ASRTransducerTask object.
            parser: Transducer arguments parser.
        """
        group = parser.add_argument_group(description="Task related.")
        # required = parser.get_default("required")
        # required += ["token_list"]
        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="Integer-string mapper for tokens.",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text using <space>.",
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of dimensions for input features.",
        )
        group.add_argument(
            "--init",
            type=str_or_none,
            default=None,
            help="Type of model initialization to use.",
        )
        group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(TransducerModel),
            help="The keyword arguments for the model class.",
        )
        # group.add_argument(
        #     "--encoder_conf",
        #     action=NestedDictAction,
        #     default={},
        #     help="The keyword arguments for the encoder class.",
        # )
        group.add_argument(
            "--joint_network_conf",
            action=NestedDictAction,
            default={},
            help="The keyword arguments for the joint network class.",
        )

        group = parser.add_argument_group(description="Preprocess related.")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Whether to apply preprocessing to input data.",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The type of tokens to use during tokenization.",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The path of the sentencepiece model.",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="The 'non_linguistic_symbols' file path.",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Text cleaner to use.",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="g2p method to use if --token_type=phn.",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Normalization value for maximum amplitude scaling.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The RIR SCP file path.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The path of the noise SCP file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --decoder and --decoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Build collate function.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.

        Return:
            : Callable collate function.
        """
        assert check_argument_types()
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Build pre-processing function.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.
            train: Training mode.

        Return:
            : Callable pre-processing function.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Required data depending on task mode.

        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.

        Return:
            retval: Required task data.
        """
        if not inference:
            retval = ("speech", "text")
        else:
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Optional data depending on task mode.

        Args:
            cls: ASRTransducerTask object.
            train: Training mode.
            inference: Inference mode.

        Return:
            retval: Optional task data.
        """
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
        """Build the ASR Transducer model.

        Args:
            cls: ASRTransducerTask object.
            args: Task arguments.

        Return:
            model: ASR Transducer model.
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Encoder
        # --encoder is registered with a default, so encoder_choices always
        # resolves the encoder class.
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()

        # 5. Decoder
        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
        decoder = rnnt_decoder_class(
            vocab_size,
            **args.rnnt_decoder_conf,
        )
        decoder_output_size = decoder.output_size
        if getattr(args, "decoder", None) is not None:
            att_decoder_class = decoder_choices.get_class(args.decoder)
            att_decoder = att_decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **args.decoder_conf,
            )
        else:
            att_decoder = None

        # 6. Joint network
        joint_network = JointNetwork(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **args.joint_network_conf,
        )

        # 7. Build model
        if encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )
        else:
            model = TransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,
                frontend=frontend,
                specaug=specaug,
                normalize=normalize,
                encoder=encoder,
                decoder=decoder,
                att_decoder=att_decoder,
                joint_network=joint_network,
                **args.model_conf,
            )

        # 8. Initialize model
        if args.init is not None:
            raise NotImplementedError(
                "Currently not supported. "
                "The initialization part will be reworked in the near future."
            )

        # assert check_return_type(model)
        return model
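
# A hedged end-to-end sketch, assuming a training config.yaml saved by this
# task (the path is hypothetical):
#
#     with open("exp/transducer/config.yaml", encoding="utf-8") as f:
#         args = argparse.Namespace(**yaml.safe_load(f))
#     model = ASRTransducerTask.build_model(args)
#     model.eval()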