diar.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. """
  2. Author: Speech Lab, Alibaba Group, China
  3. SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
  4. https://arxiv.org/abs/2211.10243
  5. TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
  6. https://arxiv.org/abs/2303.05397
  7. """
  8. import argparse
  9. import logging
  10. import os
  11. from pathlib import Path
  12. from typing import Callable
  13. from typing import Collection
  14. from typing import Dict
  15. from typing import List
  16. from typing import Optional
  17. from typing import Tuple
  18. from typing import Union
  19. import numpy as np
  20. import torch
  21. import yaml
  22. from typeguard import check_argument_types
  23. from typeguard import check_return_type
  24. from funasr.datasets.collate_fn import CommonCollateFn
  25. from funasr.datasets.preprocessor import CommonPreprocessor
  26. from funasr.layers.abs_normalize import AbsNormalize
  27. from funasr.layers.global_mvn import GlobalMVN
  28. from funasr.layers.label_aggregation import LabelAggregate
  29. from funasr.layers.utterance_mvn import UtteranceMVN
  30. from funasr.models.e2e_diar_sond import DiarSondModel
  31. from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
  32. from funasr.models.encoder.abs_encoder import AbsEncoder
  33. from funasr.models.encoder.conformer_encoder import ConformerEncoder
  34. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  35. from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
  36. from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
  37. from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
  38. from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
  39. from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
  40. from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
  41. from funasr.models.encoder.rnn_encoder import RNNEncoder
  42. from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
  43. from funasr.models.encoder.transformer_encoder import TransformerEncoder
  44. from funasr.models.frontend.abs_frontend import AbsFrontend
  45. from funasr.models.frontend.default import DefaultFrontend
  46. from funasr.models.frontend.fused import FusedFrontends
  47. from funasr.models.frontend.s3prl import S3prlFrontend
  48. from funasr.models.frontend.wav_frontend import WavFrontend
  49. from funasr.models.frontend.wav_frontend import WavFrontendMel23
  50. from funasr.models.frontend.windowing import SlidingWindow
  51. from funasr.models.specaug.abs_specaug import AbsSpecAug
  52. from funasr.models.specaug.specaug import SpecAug
  53. from funasr.models.specaug.specaug import SpecAugLFR
  54. from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
  55. from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
  56. from funasr.tasks.abs_task import AbsTask
  57. from funasr.torch_utils.initialize import initialize
  58. from funasr.train.abs_espnet_model import AbsESPnetModel
  59. from funasr.train.class_choices import ClassChoices
  60. from funasr.train.trainer import Trainer
  61. from funasr.utils.types import float_or_none
  62. from funasr.utils.types import int_or_none
  63. from funasr.utils.types import str2bool
  64. from funasr.utils.types import str_or_none
# ---------------------------------------------------------------------------
# Component registries.
# Each ClassChoices registry maps a command-line alias (e.g. ``--frontend
# wav_frontend``) to the class implementing it, and is later used by the task
# classes below to add the matching --<name> / --<name>_conf CLI options.
# ---------------------------------------------------------------------------

# Acoustic feature extraction front-ends.
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        wav_frontend_mel23=WavFrontendMel23,
    ),
    type_check=AbsFrontend,
    default="default",
)

# Spectrogram augmentation (optional; disabled when --specaug is omitted).
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)

# Feature normalization layers (optional).
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)

# Aggregates frame-level labels to the model's output resolution (optional).
label_aggregator_choices = ClassChoices(
    "label_aggregator",
    classes=dict(
        label_aggregator=LabelAggregate
    ),
    type_check=torch.nn.Module,
    default=None,
    optional=True,
)

# Top-level diarization models (SOND and EEND-OLA).
model_choices = ClassChoices(
    "model",
    classes=dict(
        sond=DiarSondModel,
        eend_ola=DiarEENDOLAModel,
    ),
    type_check=AbsESPnetModel,
    default="sond",
)

# Speech encoders.
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        san=SelfAttentionEncoder,
        fsmn=FsmnEncoder,
        conv=ConvEncoder,
        resnet34=ResNet34Diar,
        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
        ecapa_tdnn=ECAPA_TDNN,
        eend_ola_transformer=EENDOLATransformerEncoder,
    ),
    type_check=torch.nn.Module,
    default="resnet34",
)

# Speaker-profile encoders (optional; used by SOND).
speaker_encoder_choices = ClassChoices(
    "speaker_encoder",
    classes=dict(
        conformer=ConformerEncoder,
        transformer=TransformerEncoder,
        rnn=RNNEncoder,
        sanm=SANMEncoder,
        san=SelfAttentionEncoder,
        fsmn=FsmnEncoder,
        conv=ConvEncoder,
        sanm_chunk_opt=SANMEncoderChunkOpt,
        data2vec_encoder=Data2VecEncoder,
    ),
    type_check=AbsEncoder,
    default=None,
    optional=True
)

# Context-dependent (CD) speaker scorer (optional; used by SOND).
cd_scorer_choices = ClassChoices(
    "cd_scorer",
    classes=dict(
        san=SelfAttentionEncoder,
    ),
    type_check=AbsEncoder,
    default=None,
    optional=True,
)

# Context-independent (CI) speaker scorer (optional; used by SOND).
ci_scorer_choices = ClassChoices(
    "ci_scorer",
    classes=dict(
        dot=DotScorer,
        cosine=CosScorer,
        conv=ConvEncoder,
    ),
    type_check=torch.nn.Module,
    default=None,
    optional=True,
)

# decoder is used for output (e.g. post_net in SOND)
decoder_choices = ClassChoices(
    "decoder",
    classes=dict(
        rnn=RNNEncoder,
        fsmn=FsmnEncoder,
    ),
    type_check=torch.nn.Module,
    default="fsmn",
)

# encoder_decoder_attractor is used for EEND-OLA
encoder_decoder_attractor_choices = ClassChoices(
    "encoder_decoder_attractor",
    classes=dict(
        eda=EncoderDecoderAttractor,
    ),
    type_check=torch.nn.Module,
    default="eda",
)
  192. class DiarTask(AbsTask):
  193. # If you need more than 1 optimizer, change this value
  194. num_optimizers: int = 1
  195. # Add variable objects configurations
  196. class_choices_list = [
  197. # --frontend and --frontend_conf
  198. frontend_choices,
  199. # --specaug and --specaug_conf
  200. specaug_choices,
  201. # --normalize and --normalize_conf
  202. normalize_choices,
  203. # --label_aggregator and --label_aggregator_conf
  204. label_aggregator_choices,
  205. # --model and --model_conf
  206. model_choices,
  207. # --encoder and --encoder_conf
  208. encoder_choices,
  209. # --speaker_encoder and --speaker_encoder_conf
  210. speaker_encoder_choices,
  211. # --cd_scorer and cd_scorer_conf
  212. cd_scorer_choices,
  213. # --ci_scorer and ci_scorer_conf
  214. ci_scorer_choices,
  215. # --decoder and --decoder_conf
  216. decoder_choices,
  217. ]
  218. # If you need to modify train() or eval() procedures, change Trainer class here
  219. trainer = Trainer
  220. @classmethod
  221. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  222. group = parser.add_argument_group(description="Task related")
  223. # NOTE(kamo): add_arguments(..., required=True) can't be used
  224. # to provide --print_config mode. Instead of it, do as
  225. # required = parser.get_default("required")
  226. # required += ["token_list"]
  227. group.add_argument(
  228. "--token_list",
  229. type=str_or_none,
  230. default=None,
  231. help="A text mapping int-id to token",
  232. )
  233. group.add_argument(
  234. "--split_with_space",
  235. type=str2bool,
  236. default=True,
  237. help="whether to split text using <space>",
  238. )
  239. group.add_argument(
  240. "--seg_dict_file",
  241. type=str,
  242. default=None,
  243. help="seg_dict_file for text processing",
  244. )
  245. group.add_argument(
  246. "--init",
  247. type=lambda x: str_or_none(x.lower()),
  248. default=None,
  249. help="The initialization method",
  250. choices=[
  251. "chainer",
  252. "xavier_uniform",
  253. "xavier_normal",
  254. "kaiming_uniform",
  255. "kaiming_normal",
  256. None,
  257. ],
  258. )
  259. group.add_argument(
  260. "--input_size",
  261. type=int_or_none,
  262. default=None,
  263. help="The number of input dimension of the feature",
  264. )
  265. group = parser.add_argument_group(description="Preprocess related")
  266. group.add_argument(
  267. "--use_preprocessor",
  268. type=str2bool,
  269. default=True,
  270. help="Apply preprocessing to data or not",
  271. )
  272. group.add_argument(
  273. "--token_type",
  274. type=str,
  275. default="char",
  276. choices=["char"],
  277. help="The text will be tokenized in the specified level token",
  278. )
  279. parser.add_argument(
  280. "--speech_volume_normalize",
  281. type=float_or_none,
  282. default=None,
  283. help="Scale the maximum amplitude to the given value.",
  284. )
  285. parser.add_argument(
  286. "--rir_scp",
  287. type=str_or_none,
  288. default=None,
  289. help="The file path of rir scp file.",
  290. )
  291. parser.add_argument(
  292. "--rir_apply_prob",
  293. type=float,
  294. default=1.0,
  295. help="THe probability for applying RIR convolution.",
  296. )
  297. parser.add_argument(
  298. "--cmvn_file",
  299. type=str_or_none,
  300. default=None,
  301. help="The file path of noise scp file.",
  302. )
  303. parser.add_argument(
  304. "--noise_scp",
  305. type=str_or_none,
  306. default=None,
  307. help="The file path of noise scp file.",
  308. )
  309. parser.add_argument(
  310. "--noise_apply_prob",
  311. type=float,
  312. default=1.0,
  313. help="The probability applying Noise adding.",
  314. )
  315. parser.add_argument(
  316. "--noise_db_range",
  317. type=str,
  318. default="13_15",
  319. help="The range of noise decibel level.",
  320. )
  321. for class_choices in cls.class_choices_list:
  322. # Append --<name> and --<name>_conf.
  323. # e.g. --encoder and --encoder_conf
  324. class_choices.add_arguments(group)
  325. @classmethod
  326. def build_collate_fn(
  327. cls, args: argparse.Namespace, train: bool
  328. ) -> Callable[
  329. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  330. Tuple[List[str], Dict[str, torch.Tensor]],
  331. ]:
  332. assert check_argument_types()
  333. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  334. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  335. @classmethod
  336. def build_preprocess_fn(
  337. cls, args: argparse.Namespace, train: bool
  338. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  339. assert check_argument_types()
  340. if args.use_preprocessor:
  341. retval = CommonPreprocessor(
  342. train=train,
  343. token_type=args.token_type,
  344. token_list=args.token_list,
  345. bpemodel=None,
  346. non_linguistic_symbols=None,
  347. text_cleaner=None,
  348. g2p_type=None,
  349. split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  350. seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  351. # NOTE(kamo): Check attribute existence for backward compatibility
  352. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  353. rir_apply_prob=args.rir_apply_prob
  354. if hasattr(args, "rir_apply_prob")
  355. else 1.0,
  356. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  357. noise_apply_prob=args.noise_apply_prob
  358. if hasattr(args, "noise_apply_prob")
  359. else 1.0,
  360. noise_db_range=args.noise_db_range
  361. if hasattr(args, "noise_db_range")
  362. else "13_15",
  363. speech_volume_normalize=args.speech_volume_normalize
  364. if hasattr(args, "rir_scp")
  365. else None,
  366. )
  367. else:
  368. retval = None
  369. assert check_return_type(retval)
  370. return retval
  371. @classmethod
  372. def required_data_names(
  373. cls, train: bool = True, inference: bool = False
  374. ) -> Tuple[str, ...]:
  375. if not inference:
  376. retval = ("speech", "profile", "binary_labels")
  377. else:
  378. # Recognition mode
  379. retval = ("speech", "profile")
  380. return retval
  381. @classmethod
  382. def optional_data_names(
  383. cls, train: bool = True, inference: bool = False
  384. ) -> Tuple[str, ...]:
  385. retval = ()
  386. assert check_return_type(retval)
  387. return retval
  388. @classmethod
  389. def build_model(cls, args: argparse.Namespace):
  390. assert check_argument_types()
  391. if isinstance(args.token_list, str):
  392. with open(args.token_list, encoding="utf-8") as f:
  393. token_list = [line.rstrip() for line in f]
  394. # Overwriting token_list to keep it as "portable".
  395. args.token_list = list(token_list)
  396. elif isinstance(args.token_list, (tuple, list)):
  397. token_list = list(args.token_list)
  398. else:
  399. raise RuntimeError("token_list must be str or list")
  400. vocab_size = len(token_list)
  401. logging.info(f"Vocabulary size: {vocab_size}")
  402. # 1. frontend
  403. if args.input_size is None:
  404. # Extract features in the model
  405. frontend_class = frontend_choices.get_class(args.frontend)
  406. if args.frontend == 'wav_frontend':
  407. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  408. else:
  409. frontend = frontend_class(**args.frontend_conf)
  410. input_size = frontend.output_size()
  411. else:
  412. # Give features from data-loader
  413. args.frontend = None
  414. args.frontend_conf = {}
  415. frontend = None
  416. input_size = args.input_size
  417. # 2. Data augmentation for spectrogram
  418. if args.specaug is not None:
  419. specaug_class = specaug_choices.get_class(args.specaug)
  420. specaug = specaug_class(**args.specaug_conf)
  421. else:
  422. specaug = None
  423. # 3. Normalization layer
  424. if args.normalize is not None:
  425. normalize_class = normalize_choices.get_class(args.normalize)
  426. normalize = normalize_class(**args.normalize_conf)
  427. else:
  428. normalize = None
  429. # 4. Encoder
  430. encoder_class = encoder_choices.get_class(args.encoder)
  431. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  432. # 5. speaker encoder
  433. if getattr(args, "speaker_encoder", None) is not None:
  434. speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
  435. speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
  436. else:
  437. speaker_encoder = None
  438. # 6. CI & CD scorer
  439. if getattr(args, "ci_scorer", None) is not None:
  440. ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
  441. ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
  442. else:
  443. ci_scorer = None
  444. if getattr(args, "cd_scorer", None) is not None:
  445. cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
  446. cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
  447. else:
  448. cd_scorer = None
  449. # 7. Decoder
  450. decoder_class = decoder_choices.get_class(args.decoder)
  451. decoder = decoder_class(**args.decoder_conf)
  452. if getattr(args, "label_aggregator", None) is not None:
  453. label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
  454. label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
  455. else:
  456. label_aggregator = None
  457. # 9. Build model
  458. model_class = model_choices.get_class(args.model)
  459. model = model_class(
  460. vocab_size=vocab_size,
  461. frontend=frontend,
  462. specaug=specaug,
  463. normalize=normalize,
  464. label_aggregator=label_aggregator,
  465. encoder=encoder,
  466. speaker_encoder=speaker_encoder,
  467. ci_scorer=ci_scorer,
  468. cd_scorer=cd_scorer,
  469. decoder=decoder,
  470. token_list=token_list,
  471. **args.model_conf,
  472. )
  473. # 10. Initialize
  474. if args.init is not None:
  475. initialize(model, args.init)
  476. assert check_return_type(model)
  477. return model
  478. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  479. @classmethod
  480. def build_model_from_file(
  481. cls,
  482. config_file: Union[Path, str] = None,
  483. model_file: Union[Path, str] = None,
  484. cmvn_file: Union[Path, str] = None,
  485. device: Union[str, torch.device] = "cpu",
  486. ):
  487. """Build model from the files.
  488. This method is used for inference or fine-tuning.
  489. Args:
  490. config_file: The yaml file saved when training.
  491. model_file: The model file saved when training.
  492. cmvn_file: The cmvn file for front-end
  493. device: Device type, "cpu", "cuda", or "cuda:N".
  494. """
  495. assert check_argument_types()
  496. if config_file is None:
  497. assert model_file is not None, (
  498. "The argument 'model_file' must be provided "
  499. "if the argument 'config_file' is not specified."
  500. )
  501. config_file = Path(model_file).parent / "config.yaml"
  502. else:
  503. config_file = Path(config_file)
  504. with config_file.open("r", encoding="utf-8") as f:
  505. args = yaml.safe_load(f)
  506. if cmvn_file is not None:
  507. args["cmvn_file"] = cmvn_file
  508. args = argparse.Namespace(**args)
  509. model = cls.build_model(args)
  510. if not isinstance(model, AbsESPnetModel):
  511. raise RuntimeError(
  512. f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
  513. )
  514. model.to(device)
  515. model_dict = dict()
  516. model_name_pth = None
  517. if model_file is not None:
  518. logging.info("model_file is {}".format(model_file))
  519. if device == "cuda":
  520. device = f"cuda:{torch.cuda.current_device()}"
  521. model_dir = os.path.dirname(model_file)
  522. model_name = os.path.basename(model_file)
  523. if "model.ckpt-" in model_name or ".bin" in model_name:
  524. if ".bin" in model_name:
  525. model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
  526. else:
  527. model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
  528. if os.path.exists(model_name_pth):
  529. logging.info("model_file is load from pth: {}".format(model_name_pth))
  530. model_dict = torch.load(model_name_pth, map_location=device)
  531. else:
  532. model_dict = cls.convert_tf2torch(model, model_file)
  533. model.load_state_dict(model_dict)
  534. else:
  535. model_dict = torch.load(model_file, map_location=device)
  536. model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
  537. model.load_state_dict(model_dict)
  538. if model_name_pth is not None and not os.path.exists(model_name_pth):
  539. torch.save(model_dict, model_name_pth)
  540. logging.info("model_file is saved to pth: {}".format(model_name_pth))
  541. return model, args
  542. @classmethod
  543. def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
  544. from collections import OrderedDict
  545. new_dict = OrderedDict()
  546. for key, value in src_dict.items():
  547. if key in dest_dict:
  548. new_dict[key] = value
  549. else:
  550. logging.info("{} is no longer needed in this model.".format(key))
  551. for key, value in dest_dict.items():
  552. if key not in new_dict:
  553. logging.warning("{} is missed in checkpoint.".format(key))
  554. return new_dict
  555. @classmethod
  556. def convert_tf2torch(
  557. cls,
  558. model,
  559. ckpt,
  560. ):
  561. logging.info("start convert tf model to torch model")
  562. from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
  563. var_dict_tf = load_tf_dict(ckpt)
  564. var_dict_torch = model.state_dict()
  565. var_dict_torch_update = dict()
  566. # speech encoder
  567. if model.encoder is not None:
  568. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  569. var_dict_torch_update.update(var_dict_torch_update_local)
  570. # speaker encoder
  571. if model.speaker_encoder is not None:
  572. var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  573. var_dict_torch_update.update(var_dict_torch_update_local)
  574. # cd scorer
  575. if model.cd_scorer is not None:
  576. var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  577. var_dict_torch_update.update(var_dict_torch_update_local)
  578. # ci scorer
  579. if model.ci_scorer is not None:
  580. var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  581. var_dict_torch_update.update(var_dict_torch_update_local)
  582. # decoder
  583. if model.decoder is not None:
  584. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  585. var_dict_torch_update.update(var_dict_torch_update_local)
  586. return var_dict_torch_update
  587. class EENDOLADiarTask(AbsTask):
  588. # If you need more than 1 optimizer, change this value
  589. num_optimizers: int = 1
  590. # Add variable objects configurations
  591. class_choices_list = [
  592. # --frontend and --frontend_conf
  593. frontend_choices,
  594. # --specaug and --specaug_conf
  595. model_choices,
  596. # --encoder and --encoder_conf
  597. encoder_choices,
  598. # --speaker_encoder and --speaker_encoder_conf
  599. encoder_decoder_attractor_choices,
  600. ]
  601. # If you need to modify train() or eval() procedures, change Trainer class here
  602. trainer = Trainer
  603. @classmethod
  604. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  605. group = parser.add_argument_group(description="Task related")
  606. # NOTE(kamo): add_arguments(..., required=True) can't be used
  607. # to provide --print_config mode. Instead of it, do as
  608. # required = parser.get_default("required")
  609. # required += ["token_list"]
  610. group.add_argument(
  611. "--token_list",
  612. type=str_or_none,
  613. default=None,
  614. help="A text mapping int-id to token",
  615. )
  616. group.add_argument(
  617. "--split_with_space",
  618. type=str2bool,
  619. default=True,
  620. help="whether to split text using <space>",
  621. )
  622. group.add_argument(
  623. "--seg_dict_file",
  624. type=str,
  625. default=None,
  626. help="seg_dict_file for text processing",
  627. )
  628. group.add_argument(
  629. "--init",
  630. type=lambda x: str_or_none(x.lower()),
  631. default=None,
  632. help="The initialization method",
  633. choices=[
  634. "chainer",
  635. "xavier_uniform",
  636. "xavier_normal",
  637. "kaiming_uniform",
  638. "kaiming_normal",
  639. None,
  640. ],
  641. )
  642. group.add_argument(
  643. "--input_size",
  644. type=int_or_none,
  645. default=None,
  646. help="The number of input dimension of the feature",
  647. )
  648. group = parser.add_argument_group(description="Preprocess related")
  649. group.add_argument(
  650. "--use_preprocessor",
  651. type=str2bool,
  652. default=True,
  653. help="Apply preprocessing to data or not",
  654. )
  655. group.add_argument(
  656. "--token_type",
  657. type=str,
  658. default="char",
  659. choices=["char"],
  660. help="The text will be tokenized in the specified level token",
  661. )
  662. parser.add_argument(
  663. "--speech_volume_normalize",
  664. type=float_or_none,
  665. default=None,
  666. help="Scale the maximum amplitude to the given value.",
  667. )
  668. parser.add_argument(
  669. "--rir_scp",
  670. type=str_or_none,
  671. default=None,
  672. help="The file path of rir scp file.",
  673. )
  674. parser.add_argument(
  675. "--rir_apply_prob",
  676. type=float,
  677. default=1.0,
  678. help="THe probability for applying RIR convolution.",
  679. )
  680. parser.add_argument(
  681. "--cmvn_file",
  682. type=str_or_none,
  683. default=None,
  684. help="The file path of noise scp file.",
  685. )
  686. parser.add_argument(
  687. "--noise_scp",
  688. type=str_or_none,
  689. default=None,
  690. help="The file path of noise scp file.",
  691. )
  692. parser.add_argument(
  693. "--noise_apply_prob",
  694. type=float,
  695. default=1.0,
  696. help="The probability applying Noise adding.",
  697. )
  698. parser.add_argument(
  699. "--noise_db_range",
  700. type=str,
  701. default="13_15",
  702. help="The range of noise decibel level.",
  703. )
  704. for class_choices in cls.class_choices_list:
  705. # Append --<name> and --<name>_conf.
  706. # e.g. --encoder and --encoder_conf
  707. class_choices.add_arguments(group)
  708. @classmethod
  709. def build_collate_fn(
  710. cls, args: argparse.Namespace, train: bool
  711. ) -> Callable[
  712. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  713. Tuple[List[str], Dict[str, torch.Tensor]],
  714. ]:
  715. assert check_argument_types()
  716. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  717. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  718. @classmethod
  719. def build_preprocess_fn(
  720. cls, args: argparse.Namespace, train: bool
  721. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  722. assert check_argument_types()
  723. # if args.use_preprocessor:
  724. # retval = CommonPreprocessor(
  725. # train=train,
  726. # token_type=args.token_type,
  727. # token_list=args.token_list,
  728. # bpemodel=None,
  729. # non_linguistic_symbols=None,
  730. # text_cleaner=None,
  731. # g2p_type=None,
  732. # split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  733. # seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  734. # # NOTE(kamo): Check attribute existence for backward compatibility
  735. # rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  736. # rir_apply_prob=args.rir_apply_prob
  737. # if hasattr(args, "rir_apply_prob")
  738. # else 1.0,
  739. # noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  740. # noise_apply_prob=args.noise_apply_prob
  741. # if hasattr(args, "noise_apply_prob")
  742. # else 1.0,
  743. # noise_db_range=args.noise_db_range
  744. # if hasattr(args, "noise_db_range")
  745. # else "13_15",
  746. # speech_volume_normalize=args.speech_volume_normalize
  747. # if hasattr(args, "rir_scp")
  748. # else None,
  749. # )
  750. # else:
  751. # retval = None
  752. # assert check_return_type(retval)
  753. return None
  754. @classmethod
  755. def required_data_names(
  756. cls, train: bool = True, inference: bool = False
  757. ) -> Tuple[str, ...]:
  758. if not inference:
  759. retval = ("speech", )
  760. else:
  761. # Recognition mode
  762. retval = ("speech", )
  763. return retval
  764. @classmethod
  765. def optional_data_names(
  766. cls, train: bool = True, inference: bool = False
  767. ) -> Tuple[str, ...]:
  768. retval = ()
  769. assert check_return_type(retval)
  770. return retval
  771. @classmethod
  772. def build_model(cls, args: argparse.Namespace):
  773. assert check_argument_types()
  774. # 1. frontend
  775. if args.input_size is None or args.frontend == "wav_frontend_mel23":
  776. # Extract features in the model
  777. frontend_class = frontend_choices.get_class(args.frontend)
  778. if args.frontend == 'wav_frontend':
  779. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  780. else:
  781. frontend = frontend_class(**args.frontend_conf)
  782. input_size = frontend.output_size()
  783. else:
  784. # Give features from data-loader
  785. args.frontend = None
  786. args.frontend_conf = {}
  787. frontend = None
  788. input_size = args.input_size
  789. # 2. Encoder
  790. encoder_class = encoder_choices.get_class(args.encoder)
  791. encoder = encoder_class(**args.encoder_conf)
  792. # 3. EncoderDecoderAttractor
  793. encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
  794. encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
  795. # 9. Build model
  796. model_class = model_choices.get_class(args.model)
  797. model = model_class(
  798. frontend=frontend,
  799. encoder=encoder,
  800. encoder_decoder_attractor=encoder_decoder_attractor,
  801. **args.model_conf,
  802. )
  803. # 10. Initialize
  804. if args.init is not None:
  805. initialize(model, args.init)
  806. assert check_return_type(model)
  807. return model
  808. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  809. @classmethod
  810. def build_model_from_file(
  811. cls,
  812. config_file: Union[Path, str] = None,
  813. model_file: Union[Path, str] = None,
  814. cmvn_file: Union[Path, str] = None,
  815. device: str = "cpu",
  816. ):
  817. """Build model from the files.
  818. This method is used for inference or fine-tuning.
  819. Args:
  820. config_file: The yaml file saved when training.
  821. model_file: The model file saved when training.
  822. cmvn_file: The cmvn file for front-end
  823. device: Device type, "cpu", "cuda", or "cuda:N".
  824. """
  825. assert check_argument_types()
  826. if config_file is None:
  827. assert model_file is not None, (
  828. "The argument 'model_file' must be provided "
  829. "if the argument 'config_file' is not specified."
  830. )
  831. config_file = Path(model_file).parent / "config.yaml"
  832. else:
  833. config_file = Path(config_file)
  834. with config_file.open("r", encoding="utf-8") as f:
  835. args = yaml.safe_load(f)
  836. args = argparse.Namespace(**args)
  837. model = cls.build_model(args)
  838. if not isinstance(model, AbsESPnetModel):
  839. raise RuntimeError(
  840. f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
  841. )
  842. if model_file is not None:
  843. if device == "cuda":
  844. device = f"cuda:{torch.cuda.current_device()}"
  845. checkpoint = torch.load(model_file, map_location=device)
  846. if "state_dict" in checkpoint.keys():
  847. model.load_state_dict(checkpoint["state_dict"])
  848. else:
  849. model.load_state_dict(checkpoint)
  850. model.to(device)
  851. return model, args