# asr.py

import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import (
    Paraformer,
    ParaformerOnline,
    ParaformerBert,
    BiCifParaformer,
    ContextualParaformer,
)
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none

frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        multichannelfrontend=MultiChannelFrontend,
    ),
    type_check=AbsFrontend,
    default="default",
)
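# Usage sketch (illustrative, not part of the original file): each ClassChoices
# instance maps a string choice to a class and is resolved at build time, e.g.
#   >>> frontend_class = frontend_choices.get_class("wav_frontend")
#   >>> frontend = frontend_class(cmvn_file="am.mvn", fs=16000)  # kwargs are hypothetical
# It also contributes a --frontend/--frontend_conf flag pair to the task parser
# (see class_choices.add_arguments below).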
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        asr=ESPnetASRModel,
        uniasr=UniASR,
        paraformer=Paraformer,
        paraformer_online=ParaformerOnline,
        paraformer_bert=ParaformerBert,
        bicif_paraformer=BiCifParaformer,
        contextual_paraformer=ContextualParaformer,
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    default="asr",
)
preencoder_choices = ClassChoices(
    name="preencoder",
    classes=dict(
        sinc=LightweightSincConvs,
        linear=LinearProjection,
    ),
    type_check=AbsPreEncoder,
    default=None,
    optional=True,
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        mfcca_enc=MFCCAEncoder,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
encoder_choices2 = ClassChoices(
    "encoder2",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
    ),
    type_check=AbsEncoder,
    default="rnn",
)
postencoder_choices = ClassChoices(
    name="postencoder",
    classes=dict(
        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
    ),
    type_check=AbsPostEncoder,
    default=None,
    optional=True,
)
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
        paraformer_decoder_san=ParaformerDecoderSAN,
        contextual_paraformer_decoder=ContextualParaformerDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
decoder_choices2 = ClassChoices(
    "decoder2",
    classes=dict(
        transformer=TransformerDecoder,
        lightweight_conv=LightweightConvolutionTransformerDecoder,
        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
        dynamic_conv=DynamicConvolutionTransformerDecoder,
        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
        rnn=RNNDecoder,
        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
        paraformer_decoder_sanm=ParaformerSANMDecoder,
    ),
    type_check=AbsDecoder,
    default="rnn",
)
predictor_choices = ClassChoices(
    name="predictor",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
        cif_predictor_v3=CifPredictorV3,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
predictor_choices2 = ClassChoices(
    name="predictor2",
    classes=dict(
        cif_predictor=CifPredictor,
        ctc_predictor=None,
        cif_predictor_v2=CifPredictorV2,
    ),
    type_check=None,
    default="cif_predictor",
    optional=True,
)
stride_conv_choices = ClassChoices(
    name="stride_conv",
    classes=dict(
        stride_conv1d=Conv1dSubsampling,
    ),
    type_check=None,
    default="stride_conv1d",
    optional=True,
)
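# The *2 groups above (encoder2, decoder2, predictor2) mirror their first-pass
# counterparts for the two-pass UniASR model. Every ClassChoices instance in a
# task's class_choices_list becomes a --<name>/--<name>_conf flag pair on the
# command line, e.g. (script name and values are illustrative):
#   python train.py --encoder sanm --encoder_conf '{"num_blocks": 50}' \
#       --encoder2 sanm_chunk_opt --decoder2 fsmn_scama_opt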

class ASRTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]

        group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        group.add_argument(
            "--split_with_space",
            type=str2bool,
            default=True,
            help="Whether to split text using <space>",
        )
        group.add_argument(
            "--seg_dict_file",
            type=str,
            default=None,
            help="seg_dict_file for text processing",
        )
        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimensions of the feature",
        )
        group.add_argument(
            "--ctc_conf",
            action=NestedDictAction,
            default=get_default_kwargs(CTC),
            help="The keyword arguments for the CTC class.",
        )
        group.add_argument(
            "--joint_net_conf",
            action=NestedDictAction,
            default=None,
            help="The keyword arguments for the joint network class.",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized at the specified token level",
        )
        group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file of sentencepiece",
        )
        parser.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            default=None,
            help="non_linguistic_symbols file path",
        )
        parser.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        parser.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )
        parser.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        parser.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of the rir scp file.",
        )
        parser.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        parser.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the cmvn file.",
        )
        parser.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of the noise scp file.",
        )
        parser.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        parser.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of the noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
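
    # Collation sketch (illustrative; exact output keys depend on CommonCollateFn):
    #   >>> fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
    #   >>> keys, batch = fn([("utt1", {"text": np.array([2, 3])}),
    #   ...                   ("utt2", {"text": np.array([4])})])
    # Variable-length arrays are padded to the batch maximum, 0.0 for float
    # features and -1 for token ids (0 is reserved for the CTC blank), so
    # batch["text"] would look like tensor([[2, 3], [4, -1]]).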
    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                non_linguistic_symbols=args.non_linguistic_symbols,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                split_with_space=args.split_with_space
                if hasattr(args, "split_with_space")
                else False,
                seg_dict_file=args.seg_dict_file
                if hasattr(args, "seg_dict_file")
                else None,
                # NOTE(kamo): Check attribute existence for backward compatibility
                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
                rir_apply_prob=args.rir_apply_prob
                if hasattr(args, "rir_apply_prob")
                else 1.0,
                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
                noise_apply_prob=args.noise_apply_prob
                if hasattr(args, "noise_apply_prob")
                else 1.0,
                noise_db_range=args.noise_db_range
                if hasattr(args, "noise_db_range")
                else "13_15",
                speech_volume_normalize=args.speech_volume_normalize
                if hasattr(args, "speech_volume_normalize")
                else None,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval
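
    # CommonPreprocessor tokenizes the "text" entry and, at training time, can
    # apply RIR convolution and additive noise drawn from the scp files above;
    # the hasattr guards keep configs written before these options existed
    # loadable.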
    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            **args.model_conf,
        )

        # 10. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
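
    # Build sketch (illustrative; flag values are hypothetical):
    #   >>> parser = argparse.ArgumentParser()
    #   >>> ASRTask.add_task_arguments(parser)
    #   >>> args = parser.parse_args(
    #   ...     ["--token_list", "tokens.txt", "--input_size", "80",
    #   ...      "--encoder", "conformer", "--decoder", "transformer"])
    #   >>> model = ASRTask.build_model(args)  # an ESPnetASRModel by default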


class ASRTaskUniASR(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
        # --encoder2 and --encoder2_conf
        encoder_choices2,
        # --decoder2 and --decoder2_conf
        decoder_choices2,
        # --predictor2 and --predictor2_conf
        predictor_choices2,
        # --stride_conv and --stride_conv_conf
        stride_conv_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
        encoder_output_size = encoder.output_size()
        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
        stride_conv = stride_conv_class(
            **args.stride_conv_conf,
            idim=input_size + encoder_output_size,
            odim=input_size + encoder_output_size,
        )
        stride_conv_output_size = stride_conv.output_size()

        # 6. Encoder2
        encoder_class2 = encoder_choices2.get_class(args.encoder2)
        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)

        # 7. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size2 = encoder2.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 8. Decoder & Decoder2
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder_class2 = decoder_choices2.get_class(args.decoder2)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )
        decoder2 = decoder_class2(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size2,
            **args.decoder2_conf,
        )

        # 9. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )
        ctc2 = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size2, **args.ctc_conf
        )

        # 10. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)
        predictor_class2 = predictor_choices2.get_class(args.predictor2)
        predictor2 = predictor_class2(**args.predictor2_conf)

        # 11. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            ctc2=ctc2,
            encoder2=encoder2,
            decoder2=decoder2,
            predictor2=predictor2,
            stride_conv=stride_conv,
            **args.model_conf,
        )

        # 12. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file; if given, it overrides the one in config_file.
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                model_name_pth = (
                    os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                    if ".bin" in model_name
                    else os.path.join(model_dir, "{}.pb".format(model_name))
                )
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        return model, args
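
    # Inference sketch (illustrative; paths are hypothetical):
    #   >>> model, args = ASRTaskUniASR.build_model_from_file(
    #   ...     config_file="exp/config.yaml", model_file="exp/model.pb", device="cpu")
    # TensorFlow checkpoints ("model.ckpt-*") and ".bin" files are converted
    # once by convert_tf2torch and cached next to the original as "<name>.pb".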
    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start convert tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # encoder2
        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor2
        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder2
        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # stride_conv
        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update
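
    # Each submodule owns the mapping from its TF variable names to its torch
    # parameters; the per-module dicts are merged here and the result is fed to
    # model.load_state_dict() by build_model_from_file above.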


class ASRTaskParaformer(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --postencoder and --postencoder_conf
        postencoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
        # --predictor and --predictor_conf
        predictor_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(**args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Post-encoder block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        encoder_output_size = encoder.output_size()
        if getattr(args, "postencoder", None) is not None:
            postencoder_class = postencoder_choices.get_class(args.postencoder)
            postencoder = postencoder_class(
                input_size=encoder_output_size, **args.postencoder_conf
            )
            encoder_output_size = postencoder.output_size()
        else:
            postencoder = None

        # 7. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **args.decoder_conf,
        )

        # 8. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf
        )

        # 9. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 10. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            token_list=token_list,
            predictor=predictor,
            **args.model_conf,
        )

        # 11. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file; if given, it overrides the one in config_file.
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            if "model.ckpt-" in model_name or ".bin" in model_name:
                model_name_pth = (
                    os.path.join(model_dir, model_name.replace(".bin", ".pb"))
                    if ".bin" in model_name
                    else os.path.join(model_dir, "{}.pb".format(model_name))
                )
                if os.path.exists(model_name_pth):
                    logging.info("model_file is loaded from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
                model.load_state_dict(model_dict)
        if model_name_pth is not None and not os.path.exists(model_name_pth):
            torch.save(model_dict, model_name_pth)
            logging.info("model_file is saved to pth: {}".format(model_name_pth))
        model.to(device)
        return model, args
    @classmethod
    def convert_tf2torch(
        cls,
        model,
        ckpt,
    ):
        logging.info("start convert tf model to torch model")
        from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict

        var_dict_tf = load_tf_dict(ckpt)
        var_dict_torch = model.state_dict()
        var_dict_torch_update = dict()
        # encoder
        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # predictor
        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # decoder
        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        # bias_encoder
        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
        var_dict_torch_update.update(var_dict_torch_update_local)
        return var_dict_torch_update


class ASRTaskMFCCA(ASRTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --specaug and --specaug_conf
        specaug_choices,
        # --normalize and --normalize_conf
        normalize_choices,
        # --model and --model_conf
        model_choices,
        # --preencoder and --preencoder_conf
        preencoder_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")
        vocab_size = len(token_list)
        logging.info(f"Vocabulary size: {vocab_size}")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Data augmentation for spectrogram
        if args.specaug is not None:
            specaug_class = specaug_choices.get_class(args.specaug)
            specaug = specaug_class(**args.specaug_conf)
        else:
            specaug = None

        # 3. Normalization layer
        if args.normalize is not None:
            normalize_class = normalize_choices.get_class(args.normalize)
            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
        else:
            normalize = None

        # 4. Pre-encoder input block
        # NOTE(kan-bayashi): Use getattr to keep the compatibility
        if getattr(args, "preencoder", None) is not None:
            preencoder_class = preencoder_choices.get_class(args.preencoder)
            preencoder = preencoder_class(**args.preencoder_conf)
            input_size = preencoder.output_size()
        else:
            preencoder = None

        # 5. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 6. Decoder
        decoder_class = decoder_choices.get_class(args.decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder.output_size(),
            **args.decoder_conf,
        )

        # 7. CTC
        ctc = CTC(
            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
        )

        # 8. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        rnnt_decoder = None
        model = model_class(
            vocab_size=vocab_size,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            rnnt_decoder=rnnt_decoder,
            token_list=token_list,
            **args.model_conf,
        )

        # 9. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model


class ASRTaskAligner(ASRTaskParaformer):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --model and --model_conf
        model_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --decoder and --decoder_conf
        decoder_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]
            # Overwriting token_list to keep it as "portable".
            args.token_list = list(token_list)
        elif isinstance(args.token_list, (tuple, list)):
            token_list = list(args.token_list)
        else:
            raise RuntimeError("token_list must be str or list")

        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size=input_size, **args.encoder_conf)

        # 3. Predictor
        predictor_class = predictor_choices.get_class(args.predictor)
        predictor = predictor_class(**args.predictor_conf)

        # 4. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("asr")
        model = model_class(
            frontend=frontend,
            encoder=encoder,
            predictor=predictor,
            token_list=token_list,
            **args.model_conf,
        )

        # 5. Initialize
        if args.init is not None:
            initialize(model, args.init)

        assert check_return_type(model)
        return model

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ("speech", "text")
        return retval