# diar.py -- speaker diarization task definitions (SOND and EEND-OLA) for FunASR.
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
  58. frontend_choices = ClassChoices(
  59. name="frontend",
  60. classes=dict(
  61. default=DefaultFrontend,
  62. sliding_window=SlidingWindow,
  63. s3prl=S3prlFrontend,
  64. fused=FusedFrontends,
  65. wav_frontend=WavFrontend,
  66. wav_frontend_mel23=WavFrontendMel23,
  67. ),
  68. type_check=AbsFrontend,
  69. default="default",
  70. )
  71. specaug_choices = ClassChoices(
  72. name="specaug",
  73. classes=dict(
  74. specaug=SpecAug,
  75. specaug_lfr=SpecAugLFR,
  76. ),
  77. type_check=AbsSpecAug,
  78. default=None,
  79. optional=True,
  80. )
  81. normalize_choices = ClassChoices(
  82. "normalize",
  83. classes=dict(
  84. global_mvn=GlobalMVN,
  85. utterance_mvn=UtteranceMVN,
  86. ),
  87. type_check=AbsNormalize,
  88. default=None,
  89. optional=True,
  90. )
  91. label_aggregator_choices = ClassChoices(
  92. "label_aggregator",
  93. classes=dict(
  94. label_aggregator=LabelAggregate
  95. ),
  96. type_check=torch.nn.Module,
  97. default=None,
  98. optional=True,
  99. )
  100. model_choices = ClassChoices(
  101. "model",
  102. classes=dict(
  103. sond=DiarSondModel,
  104. eend_ola=DiarEENDOLAModel,
  105. ),
  106. type_check=AbsESPnetModel,
  107. default="sond",
  108. )
  109. encoder_choices = ClassChoices(
  110. "encoder",
  111. classes=dict(
  112. conformer=ConformerEncoder,
  113. transformer=TransformerEncoder,
  114. rnn=RNNEncoder,
  115. sanm=SANMEncoder,
  116. san=SelfAttentionEncoder,
  117. fsmn=FsmnEncoder,
  118. conv=ConvEncoder,
  119. resnet34=ResNet34Diar,
  120. resnet34_sp_l2reg=ResNet34SpL2RegDiar,
  121. sanm_chunk_opt=SANMEncoderChunkOpt,
  122. data2vec_encoder=Data2VecEncoder,
  123. ecapa_tdnn=ECAPA_TDNN,
  124. eend_ola_transformer=EENDOLATransformerEncoder,
  125. ),
  126. type_check=torch.nn.Module,
  127. default="resnet34",
  128. )
  129. speaker_encoder_choices = ClassChoices(
  130. "speaker_encoder",
  131. classes=dict(
  132. conformer=ConformerEncoder,
  133. transformer=TransformerEncoder,
  134. rnn=RNNEncoder,
  135. sanm=SANMEncoder,
  136. san=SelfAttentionEncoder,
  137. fsmn=FsmnEncoder,
  138. conv=ConvEncoder,
  139. sanm_chunk_opt=SANMEncoderChunkOpt,
  140. data2vec_encoder=Data2VecEncoder,
  141. ),
  142. type_check=AbsEncoder,
  143. default=None,
  144. optional=True
  145. )
  146. cd_scorer_choices = ClassChoices(
  147. "cd_scorer",
  148. classes=dict(
  149. san=SelfAttentionEncoder,
  150. ),
  151. type_check=AbsEncoder,
  152. default=None,
  153. optional=True,
  154. )
  155. ci_scorer_choices = ClassChoices(
  156. "ci_scorer",
  157. classes=dict(
  158. dot=DotScorer,
  159. cosine=CosScorer,
  160. conv=ConvEncoder,
  161. ),
  162. type_check=torch.nn.Module,
  163. default=None,
  164. optional=True,
  165. )
  166. # decoder is used for output (e.g. post_net in SOND)
  167. decoder_choices = ClassChoices(
  168. "decoder",
  169. classes=dict(
  170. rnn=RNNEncoder,
  171. fsmn=FsmnEncoder,
  172. ),
  173. type_check=torch.nn.Module,
  174. default="fsmn",
  175. )
  176. # encoder_decoder_attractor is used for EEND-OLA
  177. encoder_decoder_attractor_choices = ClassChoices(
  178. "encoder_decoder_attractor",
  179. classes=dict(
  180. eda=EncoderDecoderAttractor,
  181. ),
  182. type_check=torch.nn.Module,
  183. default="eda",
  184. )
  185. class DiarTask(AbsTask):
  186. # If you need more than 1 optimizer, change this value
  187. num_optimizers: int = 1
  188. # Add variable objects configurations
  189. class_choices_list = [
  190. # --frontend and --frontend_conf
  191. frontend_choices,
  192. # --specaug and --specaug_conf
  193. specaug_choices,
  194. # --normalize and --normalize_conf
  195. normalize_choices,
  196. # --label_aggregator and --label_aggregator_conf
  197. label_aggregator_choices,
  198. # --model and --model_conf
  199. model_choices,
  200. # --encoder and --encoder_conf
  201. encoder_choices,
  202. # --speaker_encoder and --speaker_encoder_conf
  203. speaker_encoder_choices,
  204. # --cd_scorer and cd_scorer_conf
  205. cd_scorer_choices,
  206. # --ci_scorer and ci_scorer_conf
  207. ci_scorer_choices,
  208. # --decoder and --decoder_conf
  209. decoder_choices,
  210. ]
  211. # If you need to modify train() or eval() procedures, change Trainer class here
  212. trainer = Trainer
  213. @classmethod
  214. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  215. group = parser.add_argument_group(description="Task related")
  216. # NOTE(kamo): add_arguments(..., required=True) can't be used
  217. # to provide --print_config mode. Instead of it, do as
  218. # required = parser.get_default("required")
  219. # required += ["token_list"]
  220. group.add_argument(
  221. "--token_list",
  222. type=str_or_none,
  223. default=None,
  224. help="A text mapping int-id to token",
  225. )
  226. group.add_argument(
  227. "--split_with_space",
  228. type=str2bool,
  229. default=True,
  230. help="whether to split text using <space>",
  231. )
  232. group.add_argument(
  233. "--seg_dict_file",
  234. type=str,
  235. default=None,
  236. help="seg_dict_file for text processing",
  237. )
  238. group.add_argument(
  239. "--init",
  240. type=lambda x: str_or_none(x.lower()),
  241. default=None,
  242. help="The initialization method",
  243. choices=[
  244. "chainer",
  245. "xavier_uniform",
  246. "xavier_normal",
  247. "kaiming_uniform",
  248. "kaiming_normal",
  249. None,
  250. ],
  251. )
  252. group.add_argument(
  253. "--input_size",
  254. type=int_or_none,
  255. default=None,
  256. help="The number of input dimension of the feature",
  257. )
  258. group = parser.add_argument_group(description="Preprocess related")
  259. group.add_argument(
  260. "--use_preprocessor",
  261. type=str2bool,
  262. default=True,
  263. help="Apply preprocessing to data or not",
  264. )
  265. group.add_argument(
  266. "--token_type",
  267. type=str,
  268. default="char",
  269. choices=["char"],
  270. help="The text will be tokenized in the specified level token",
  271. )
  272. parser.add_argument(
  273. "--speech_volume_normalize",
  274. type=float_or_none,
  275. default=None,
  276. help="Scale the maximum amplitude to the given value.",
  277. )
  278. parser.add_argument(
  279. "--rir_scp",
  280. type=str_or_none,
  281. default=None,
  282. help="The file path of rir scp file.",
  283. )
  284. parser.add_argument(
  285. "--rir_apply_prob",
  286. type=float,
  287. default=1.0,
  288. help="THe probability for applying RIR convolution.",
  289. )
  290. parser.add_argument(
  291. "--cmvn_file",
  292. type=str_or_none,
  293. default=None,
  294. help="The file path of noise scp file.",
  295. )
  296. parser.add_argument(
  297. "--noise_scp",
  298. type=str_or_none,
  299. default=None,
  300. help="The file path of noise scp file.",
  301. )
  302. parser.add_argument(
  303. "--noise_apply_prob",
  304. type=float,
  305. default=1.0,
  306. help="The probability applying Noise adding.",
  307. )
  308. parser.add_argument(
  309. "--noise_db_range",
  310. type=str,
  311. default="13_15",
  312. help="The range of noise decibel level.",
  313. )
  314. for class_choices in cls.class_choices_list:
  315. # Append --<name> and --<name>_conf.
  316. # e.g. --encoder and --encoder_conf
  317. class_choices.add_arguments(group)
  318. @classmethod
  319. def build_collate_fn(
  320. cls, args: argparse.Namespace, train: bool
  321. ) -> Callable[
  322. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  323. Tuple[List[str], Dict[str, torch.Tensor]],
  324. ]:
  325. assert check_argument_types()
  326. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  327. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  328. @classmethod
  329. def build_preprocess_fn(
  330. cls, args: argparse.Namespace, train: bool
  331. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  332. assert check_argument_types()
  333. if args.use_preprocessor:
  334. retval = CommonPreprocessor(
  335. train=train,
  336. token_type=args.token_type,
  337. token_list=args.token_list,
  338. bpemodel=None,
  339. non_linguistic_symbols=None,
  340. text_cleaner=None,
  341. g2p_type=None,
  342. split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  343. seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  344. # NOTE(kamo): Check attribute existence for backward compatibility
  345. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  346. rir_apply_prob=args.rir_apply_prob
  347. if hasattr(args, "rir_apply_prob")
  348. else 1.0,
  349. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  350. noise_apply_prob=args.noise_apply_prob
  351. if hasattr(args, "noise_apply_prob")
  352. else 1.0,
  353. noise_db_range=args.noise_db_range
  354. if hasattr(args, "noise_db_range")
  355. else "13_15",
  356. speech_volume_normalize=args.speech_volume_normalize
  357. if hasattr(args, "rir_scp")
  358. else None,
  359. )
  360. else:
  361. retval = None
  362. assert check_return_type(retval)
  363. return retval
  364. @classmethod
  365. def required_data_names(
  366. cls, train: bool = True, inference: bool = False
  367. ) -> Tuple[str, ...]:
  368. if not inference:
  369. retval = ("speech", "profile", "binary_labels")
  370. else:
  371. # Recognition mode
  372. retval = ("speech", "profile")
  373. return retval
  374. @classmethod
  375. def optional_data_names(
  376. cls, train: bool = True, inference: bool = False
  377. ) -> Tuple[str, ...]:
  378. retval = ()
  379. assert check_return_type(retval)
  380. return retval
  381. @classmethod
  382. def build_model(cls, args: argparse.Namespace):
  383. assert check_argument_types()
  384. if isinstance(args.token_list, str):
  385. with open(args.token_list, encoding="utf-8") as f:
  386. token_list = [line.rstrip() for line in f]
  387. # Overwriting token_list to keep it as "portable".
  388. args.token_list = list(token_list)
  389. elif isinstance(args.token_list, (tuple, list)):
  390. token_list = list(args.token_list)
  391. else:
  392. raise RuntimeError("token_list must be str or list")
  393. vocab_size = len(token_list)
  394. logging.info(f"Vocabulary size: {vocab_size}")
  395. # 1. frontend
  396. if args.input_size is None:
  397. # Extract features in the model
  398. frontend_class = frontend_choices.get_class(args.frontend)
  399. if args.frontend == 'wav_frontend':
  400. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  401. else:
  402. frontend = frontend_class(**args.frontend_conf)
  403. input_size = frontend.output_size()
  404. else:
  405. # Give features from data-loader
  406. args.frontend = None
  407. args.frontend_conf = {}
  408. frontend = None
  409. input_size = args.input_size
  410. # 2. Data augmentation for spectrogram
  411. if args.specaug is not None:
  412. specaug_class = specaug_choices.get_class(args.specaug)
  413. specaug = specaug_class(**args.specaug_conf)
  414. else:
  415. specaug = None
  416. # 3. Normalization layer
  417. if args.normalize is not None:
  418. normalize_class = normalize_choices.get_class(args.normalize)
  419. normalize = normalize_class(**args.normalize_conf)
  420. else:
  421. normalize = None
  422. # 4. Encoder
  423. encoder_class = encoder_choices.get_class(args.encoder)
  424. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  425. # 5. speaker encoder
  426. if getattr(args, "speaker_encoder", None) is not None:
  427. speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
  428. speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
  429. else:
  430. speaker_encoder = None
  431. # 6. CI & CD scorer
  432. if getattr(args, "ci_scorer", None) is not None:
  433. ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
  434. ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
  435. else:
  436. ci_scorer = None
  437. if getattr(args, "cd_scorer", None) is not None:
  438. cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
  439. cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
  440. else:
  441. cd_scorer = None
  442. # 7. Decoder
  443. decoder_class = decoder_choices.get_class(args.decoder)
  444. decoder = decoder_class(**args.decoder_conf)
  445. if getattr(args, "label_aggregator", None) is not None:
  446. label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
  447. label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
  448. else:
  449. label_aggregator = None
  450. # 9. Build model
  451. model_class = model_choices.get_class(args.model)
  452. model = model_class(
  453. vocab_size=vocab_size,
  454. frontend=frontend,
  455. specaug=specaug,
  456. normalize=normalize,
  457. label_aggregator=label_aggregator,
  458. encoder=encoder,
  459. speaker_encoder=speaker_encoder,
  460. ci_scorer=ci_scorer,
  461. cd_scorer=cd_scorer,
  462. decoder=decoder,
  463. token_list=token_list,
  464. **args.model_conf,
  465. )
  466. # 10. Initialize
  467. if args.init is not None:
  468. initialize(model, args.init)
  469. assert check_return_type(model)
  470. return model
  471. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  472. @classmethod
  473. def build_model_from_file(
  474. cls,
  475. config_file: Union[Path, str] = None,
  476. model_file: Union[Path, str] = None,
  477. cmvn_file: Union[Path, str] = None,
  478. device: str = "cpu",
  479. ):
  480. """Build model from the files.
  481. This method is used for inference or fine-tuning.
  482. Args:
  483. config_file: The yaml file saved when training.
  484. model_file: The model file saved when training.
  485. cmvn_file: The cmvn file for front-end
  486. device: Device type, "cpu", "cuda", or "cuda:N".
  487. """
  488. assert check_argument_types()
  489. if config_file is None:
  490. assert model_file is not None, (
  491. "The argument 'model_file' must be provided "
  492. "if the argument 'config_file' is not specified."
  493. )
  494. config_file = Path(model_file).parent / "config.yaml"
  495. else:
  496. config_file = Path(config_file)
  497. with config_file.open("r", encoding="utf-8") as f:
  498. args = yaml.safe_load(f)
  499. if cmvn_file is not None:
  500. args["cmvn_file"] = cmvn_file
  501. args = argparse.Namespace(**args)
  502. model = cls.build_model(args)
  503. if not isinstance(model, AbsESPnetModel):
  504. raise RuntimeError(
  505. f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
  506. )
  507. model.to(device)
  508. model_dict = dict()
  509. model_name_pth = None
  510. if model_file is not None:
  511. logging.info("model_file is {}".format(model_file))
  512. if device == "cuda":
  513. device = f"cuda:{torch.cuda.current_device()}"
  514. model_dir = os.path.dirname(model_file)
  515. model_name = os.path.basename(model_file)
  516. if "model.ckpt-" in model_name or ".bin" in model_name:
  517. if ".bin" in model_name:
  518. model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
  519. else:
  520. model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
  521. if os.path.exists(model_name_pth):
  522. logging.info("model_file is load from pth: {}".format(model_name_pth))
  523. model_dict = torch.load(model_name_pth, map_location=device)
  524. else:
  525. model_dict = cls.convert_tf2torch(model, model_file)
  526. model.load_state_dict(model_dict)
  527. else:
  528. model_dict = torch.load(model_file, map_location=device)
  529. model.load_state_dict(model_dict)
  530. if model_name_pth is not None and not os.path.exists(model_name_pth):
  531. torch.save(model_dict, model_name_pth)
  532. logging.info("model_file is saved to pth: {}".format(model_name_pth))
  533. return model, args
  534. @classmethod
  535. def convert_tf2torch(
  536. cls,
  537. model,
  538. ckpt,
  539. ):
  540. logging.info("start convert tf model to torch model")
  541. from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
  542. var_dict_tf = load_tf_dict(ckpt)
  543. var_dict_torch = model.state_dict()
  544. var_dict_torch_update = dict()
  545. # speech encoder
  546. if model.encoder is not None:
  547. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  548. var_dict_torch_update.update(var_dict_torch_update_local)
  549. # speaker encoder
  550. if model.speaker_encoder is not None:
  551. var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  552. var_dict_torch_update.update(var_dict_torch_update_local)
  553. # cd scorer
  554. if model.cd_scorer is not None:
  555. var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  556. var_dict_torch_update.update(var_dict_torch_update_local)
  557. # ci scorer
  558. if model.ci_scorer is not None:
  559. var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  560. var_dict_torch_update.update(var_dict_torch_update_local)
  561. # decoder
  562. if model.decoder is not None:
  563. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  564. var_dict_torch_update.update(var_dict_torch_update_local)
  565. return var_dict_torch_update
  566. class EENDOLADiarTask(AbsTask):
  567. # If you need more than 1 optimizer, change this value
  568. num_optimizers: int = 1
  569. # Add variable objects configurations
  570. class_choices_list = [
  571. # --frontend and --frontend_conf
  572. frontend_choices,
  573. # --specaug and --specaug_conf
  574. model_choices,
  575. # --encoder and --encoder_conf
  576. encoder_choices,
  577. # --speaker_encoder and --speaker_encoder_conf
  578. encoder_decoder_attractor_choices,
  579. ]
  580. # If you need to modify train() or eval() procedures, change Trainer class here
  581. trainer = Trainer
  582. @classmethod
  583. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  584. group = parser.add_argument_group(description="Task related")
  585. # NOTE(kamo): add_arguments(..., required=True) can't be used
  586. # to provide --print_config mode. Instead of it, do as
  587. # required = parser.get_default("required")
  588. # required += ["token_list"]
  589. group.add_argument(
  590. "--token_list",
  591. type=str_or_none,
  592. default=None,
  593. help="A text mapping int-id to token",
  594. )
  595. group.add_argument(
  596. "--split_with_space",
  597. type=str2bool,
  598. default=True,
  599. help="whether to split text using <space>",
  600. )
  601. group.add_argument(
  602. "--seg_dict_file",
  603. type=str,
  604. default=None,
  605. help="seg_dict_file for text processing",
  606. )
  607. group.add_argument(
  608. "--init",
  609. type=lambda x: str_or_none(x.lower()),
  610. default=None,
  611. help="The initialization method",
  612. choices=[
  613. "chainer",
  614. "xavier_uniform",
  615. "xavier_normal",
  616. "kaiming_uniform",
  617. "kaiming_normal",
  618. None,
  619. ],
  620. )
  621. group.add_argument(
  622. "--input_size",
  623. type=int_or_none,
  624. default=None,
  625. help="The number of input dimension of the feature",
  626. )
  627. group = parser.add_argument_group(description="Preprocess related")
  628. group.add_argument(
  629. "--use_preprocessor",
  630. type=str2bool,
  631. default=True,
  632. help="Apply preprocessing to data or not",
  633. )
  634. group.add_argument(
  635. "--token_type",
  636. type=str,
  637. default="char",
  638. choices=["char"],
  639. help="The text will be tokenized in the specified level token",
  640. )
  641. parser.add_argument(
  642. "--speech_volume_normalize",
  643. type=float_or_none,
  644. default=None,
  645. help="Scale the maximum amplitude to the given value.",
  646. )
  647. parser.add_argument(
  648. "--rir_scp",
  649. type=str_or_none,
  650. default=None,
  651. help="The file path of rir scp file.",
  652. )
  653. parser.add_argument(
  654. "--rir_apply_prob",
  655. type=float,
  656. default=1.0,
  657. help="THe probability for applying RIR convolution.",
  658. )
  659. parser.add_argument(
  660. "--cmvn_file",
  661. type=str_or_none,
  662. default=None,
  663. help="The file path of noise scp file.",
  664. )
  665. parser.add_argument(
  666. "--noise_scp",
  667. type=str_or_none,
  668. default=None,
  669. help="The file path of noise scp file.",
  670. )
  671. parser.add_argument(
  672. "--noise_apply_prob",
  673. type=float,
  674. default=1.0,
  675. help="The probability applying Noise adding.",
  676. )
  677. parser.add_argument(
  678. "--noise_db_range",
  679. type=str,
  680. default="13_15",
  681. help="The range of noise decibel level.",
  682. )
  683. for class_choices in cls.class_choices_list:
  684. # Append --<name> and --<name>_conf.
  685. # e.g. --encoder and --encoder_conf
  686. class_choices.add_arguments(group)
  687. @classmethod
  688. def build_collate_fn(
  689. cls, args: argparse.Namespace, train: bool
  690. ) -> Callable[
  691. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  692. Tuple[List[str], Dict[str, torch.Tensor]],
  693. ]:
  694. assert check_argument_types()
  695. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  696. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  697. @classmethod
  698. def build_preprocess_fn(
  699. cls, args: argparse.Namespace, train: bool
  700. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  701. assert check_argument_types()
  702. if args.use_preprocessor:
  703. retval = CommonPreprocessor(
  704. train=train,
  705. token_type=args.token_type,
  706. token_list=args.token_list,
  707. bpemodel=None,
  708. non_linguistic_symbols=None,
  709. text_cleaner=None,
  710. g2p_type=None,
  711. split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  712. seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  713. # NOTE(kamo): Check attribute existence for backward compatibility
  714. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  715. rir_apply_prob=args.rir_apply_prob
  716. if hasattr(args, "rir_apply_prob")
  717. else 1.0,
  718. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  719. noise_apply_prob=args.noise_apply_prob
  720. if hasattr(args, "noise_apply_prob")
  721. else 1.0,
  722. noise_db_range=args.noise_db_range
  723. if hasattr(args, "noise_db_range")
  724. else "13_15",
  725. speech_volume_normalize=args.speech_volume_normalize
  726. if hasattr(args, "rir_scp")
  727. else None,
  728. )
  729. else:
  730. retval = None
  731. assert check_return_type(retval)
  732. return retval
  733. @classmethod
  734. def required_data_names(
  735. cls, train: bool = True, inference: bool = False
  736. ) -> Tuple[str, ...]:
  737. if not inference:
  738. retval = ("speech", "profile", "binary_labels")
  739. else:
  740. # Recognition mode
  741. retval = ("speech")
  742. return retval
  743. @classmethod
  744. def optional_data_names(
  745. cls, train: bool = True, inference: bool = False
  746. ) -> Tuple[str, ...]:
  747. retval = ()
  748. assert check_return_type(retval)
  749. return retval
  750. @classmethod
  751. def build_model(cls, args: argparse.Namespace):
  752. assert check_argument_types()
  753. # 1. frontend
  754. if args.input_size is None or args.frontend == "wav_frontend_mel23":
  755. # Extract features in the model
  756. frontend_class = frontend_choices.get_class(args.frontend)
  757. if args.frontend == 'wav_frontend':
  758. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  759. else:
  760. frontend = frontend_class(**args.frontend_conf)
  761. input_size = frontend.output_size()
  762. else:
  763. # Give features from data-loader
  764. args.frontend = None
  765. args.frontend_conf = {}
  766. frontend = None
  767. input_size = args.input_size
  768. # 2. Encoder
  769. encoder_class = encoder_choices.get_class(args.encoder)
  770. encoder = encoder_class(**args.encoder_conf)
  771. # 3. EncoderDecoderAttractor
  772. encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
  773. encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
  774. # 9. Build model
  775. model_class = model_choices.get_class(args.model)
  776. model = model_class(
  777. frontend=frontend,
  778. encoder=encoder,
  779. encoder_decoder_attractor=encoder_decoder_attractor,
  780. **args.model_conf,
  781. )
  782. # 10. Initialize
  783. if args.init is not None:
  784. initialize(model, args.init)
  785. assert check_return_type(model)
  786. return model
  787. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  788. @classmethod
  789. def build_model_from_file(
  790. cls,
  791. config_file: Union[Path, str] = None,
  792. model_file: Union[Path, str] = None,
  793. cmvn_file: Union[Path, str] = None,
  794. device: str = "cpu",
  795. ):
  796. """Build model from the files.
  797. This method is used for inference or fine-tuning.
  798. Args:
  799. config_file: The yaml file saved when training.
  800. model_file: The model file saved when training.
  801. cmvn_file: The cmvn file for front-end
  802. device: Device type, "cpu", "cuda", or "cuda:N".
  803. """
  804. assert check_argument_types()
  805. if config_file is None:
  806. assert model_file is not None, (
  807. "The argument 'model_file' must be provided "
  808. "if the argument 'config_file' is not specified."
  809. )
  810. config_file = Path(model_file).parent / "config.yaml"
  811. else:
  812. config_file = Path(config_file)
  813. with config_file.open("r", encoding="utf-8") as f:
  814. args = yaml.safe_load(f)
  815. args = argparse.Namespace(**args)
  816. model = cls.build_model(args)
  817. if not isinstance(model, AbsESPnetModel):
  818. raise RuntimeError(
  819. f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
  820. )
  821. if model_file is not None:
  822. if device == "cuda":
  823. device = f"cuda:{torch.cuda.current_device()}"
  824. checkpoint = torch.load(model_file, map_location=device)
  825. if "state_dict" in checkpoint.keys():
  826. model.load_state_dict(checkpoint["state_dict"])
  827. else:
  828. model.load_state_dict(checkpoint)
  829. model.to(device)
  830. return model, args