# diar.py
  1. import argparse
  2. import logging
  3. import os
  4. from pathlib import Path
  5. from typing import Callable
  6. from typing import Collection
  7. from typing import Dict
  8. from typing import List
  9. from typing import Optional
  10. from typing import Tuple
  11. from typing import Union
  12. import numpy as np
  13. import torch
  14. import yaml
  15. from typeguard import check_argument_types
  16. from typeguard import check_return_type
  17. from funasr.datasets.collate_fn import CommonCollateFn
  18. from funasr.datasets.preprocessor import CommonPreprocessor
  19. from funasr.layers.abs_normalize import AbsNormalize
  20. from funasr.layers.global_mvn import GlobalMVN
  21. from funasr.layers.utterance_mvn import UtteranceMVN
  22. from funasr.layers.label_aggregation import LabelAggregate
  23. from funasr.models.ctc import CTC
  24. from funasr.models.encoder.resnet34_encoder import ResNet34Diar
  25. from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
  26. from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
  27. from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
  28. from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
  29. from funasr.models.e2e_diar_sond import DiarSondModel
  30. from funasr.models.encoder.abs_encoder import AbsEncoder
  31. from funasr.models.encoder.conformer_encoder import ConformerEncoder
  32. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  33. from funasr.models.encoder.rnn_encoder import RNNEncoder
  34. from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
  35. from funasr.models.encoder.transformer_encoder import TransformerEncoder
  36. from funasr.models.frontend.abs_frontend import AbsFrontend
  37. from funasr.models.frontend.default import DefaultFrontend
  38. from funasr.models.frontend.fused import FusedFrontends
  39. from funasr.models.frontend.s3prl import S3prlFrontend
  40. from funasr.models.frontend.wav_frontend import WavFrontend
  41. from funasr.models.frontend.windowing import SlidingWindow
  42. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  43. from funasr.models.postencoder.hugging_face_transformers_postencoder import (
  44. HuggingFaceTransformersPostEncoder, # noqa: H301
  45. )
  46. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  47. from funasr.models.preencoder.linear import LinearProjection
  48. from funasr.models.preencoder.sinc import LightweightSincConvs
  49. from funasr.models.specaug.abs_specaug import AbsSpecAug
  50. from funasr.models.specaug.specaug import SpecAug
  51. from funasr.models.specaug.specaug import SpecAugLFR
  52. from funasr.tasks.abs_task import AbsTask
  53. from funasr.torch_utils.initialize import initialize
  54. from funasr.train.abs_espnet_model import AbsESPnetModel
  55. from funasr.train.class_choices import ClassChoices
  56. from funasr.train.trainer import Trainer
  57. from funasr.utils.types import float_or_none
  58. from funasr.utils.types import int_or_none
  59. from funasr.utils.types import str2bool
  60. from funasr.utils.types import str_or_none
  61. frontend_choices = ClassChoices(
  62. name="frontend",
  63. classes=dict(
  64. default=DefaultFrontend,
  65. sliding_window=SlidingWindow,
  66. s3prl=S3prlFrontend,
  67. fused=FusedFrontends,
  68. wav_frontend=WavFrontend,
  69. ),
  70. type_check=AbsFrontend,
  71. default="default",
  72. )
  73. specaug_choices = ClassChoices(
  74. name="specaug",
  75. classes=dict(
  76. specaug=SpecAug,
  77. specaug_lfr=SpecAugLFR,
  78. ),
  79. type_check=AbsSpecAug,
  80. default=None,
  81. optional=True,
  82. )
  83. normalize_choices = ClassChoices(
  84. "normalize",
  85. classes=dict(
  86. global_mvn=GlobalMVN,
  87. utterance_mvn=UtteranceMVN,
  88. ),
  89. type_check=AbsNormalize,
  90. default=None,
  91. optional=True,
  92. )
  93. label_aggregator_choices = ClassChoices(
  94. "label_aggregator",
  95. classes=dict(
  96. label_aggregator=LabelAggregate
  97. ),
  98. type_check=torch.nn.Module,
  99. default=None,
  100. optional=True,
  101. )
  102. model_choices = ClassChoices(
  103. "model",
  104. classes=dict(
  105. sond=DiarSondModel,
  106. ),
  107. type_check=AbsESPnetModel,
  108. default="sond",
  109. )
  110. encoder_choices = ClassChoices(
  111. "encoder",
  112. classes=dict(
  113. conformer=ConformerEncoder,
  114. transformer=TransformerEncoder,
  115. rnn=RNNEncoder,
  116. sanm=SANMEncoder,
  117. san=SelfAttentionEncoder,
  118. fsmn=FsmnEncoder,
  119. conv=ConvEncoder,
  120. resnet34=ResNet34Diar,
  121. sanm_chunk_opt=SANMEncoderChunkOpt,
  122. data2vec_encoder=Data2VecEncoder,
  123. ),
  124. type_check=AbsEncoder,
  125. default="resnet34",
  126. )
  127. speaker_encoder_choices = ClassChoices(
  128. "speaker_encoder",
  129. classes=dict(
  130. conformer=ConformerEncoder,
  131. transformer=TransformerEncoder,
  132. rnn=RNNEncoder,
  133. sanm=SANMEncoder,
  134. san=SelfAttentionEncoder,
  135. fsmn=FsmnEncoder,
  136. conv=ConvEncoder,
  137. sanm_chunk_opt=SANMEncoderChunkOpt,
  138. data2vec_encoder=Data2VecEncoder,
  139. ),
  140. type_check=AbsEncoder,
  141. default=None,
  142. optional=True
  143. )
  144. cd_scorer_choices = ClassChoices(
  145. "cd_scorer",
  146. classes=dict(
  147. san=SelfAttentionEncoder,
  148. ),
  149. type_check=AbsEncoder,
  150. default=None,
  151. optional=True,
  152. )
  153. ci_scorer_choices = ClassChoices(
  154. "ci_scorer",
  155. classes=dict(
  156. dot=DotScorer,
  157. cosine=CosScorer,
  158. ),
  159. type_check=torch.nn.Module,
  160. default=None,
  161. optional=True,
  162. )
  163. # decoder is used for output (e.g. post_net in SOND)
  164. decoder_choices = ClassChoices(
  165. "decoder",
  166. classes=dict(
  167. rnn=RNNEncoder,
  168. fsmn=FsmnEncoder,
  169. ),
  170. type_check=torch.nn.Module,
  171. default="fsmn",
  172. )
  173. class DiarTask(AbsTask):
  174. # If you need more than 1 optimizer, change this value
  175. num_optimizers: int = 1
  176. # Add variable objects configurations
  177. class_choices_list = [
  178. # --frontend and --frontend_conf
  179. frontend_choices,
  180. # --specaug and --specaug_conf
  181. specaug_choices,
  182. # --normalize and --normalize_conf
  183. normalize_choices,
  184. # --model and --model_conf
  185. model_choices,
  186. # --encoder and --encoder_conf
  187. encoder_choices,
  188. # --speaker_encoder and --speaker_encoder_conf
  189. speaker_encoder_choices,
  190. # --cd_scorer and cd_scorer_conf
  191. cd_scorer_choices,
  192. # --ci_scorer and ci_scorer_conf
  193. ci_scorer_choices,
  194. # --decoder and --decoder_conf
  195. decoder_choices,
  196. ]
  197. # If you need to modify train() or eval() procedures, change Trainer class here
  198. trainer = Trainer
  199. @classmethod
  200. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  201. group = parser.add_argument_group(description="Task related")
  202. # NOTE(kamo): add_arguments(..., required=True) can't be used
  203. # to provide --print_config mode. Instead of it, do as
  204. # required = parser.get_default("required")
  205. # required += ["token_list"]
  206. group.add_argument(
  207. "--token_list",
  208. type=str_or_none,
  209. default=None,
  210. help="A text mapping int-id to token",
  211. )
  212. group.add_argument(
  213. "--split_with_space",
  214. type=str2bool,
  215. default=True,
  216. help="whether to split text using <space>",
  217. )
  218. group.add_argument(
  219. "--seg_dict_file",
  220. type=str,
  221. default=None,
  222. help="seg_dict_file for text processing",
  223. )
  224. group.add_argument(
  225. "--init",
  226. type=lambda x: str_or_none(x.lower()),
  227. default=None,
  228. help="The initialization method",
  229. choices=[
  230. "chainer",
  231. "xavier_uniform",
  232. "xavier_normal",
  233. "kaiming_uniform",
  234. "kaiming_normal",
  235. None,
  236. ],
  237. )
  238. group.add_argument(
  239. "--input_size",
  240. type=int_or_none,
  241. default=None,
  242. help="The number of input dimension of the feature",
  243. )
  244. group = parser.add_argument_group(description="Preprocess related")
  245. group.add_argument(
  246. "--use_preprocessor",
  247. type=str2bool,
  248. default=True,
  249. help="Apply preprocessing to data or not",
  250. )
  251. group.add_argument(
  252. "--token_type",
  253. type=str,
  254. default="char",
  255. choices=["char"],
  256. help="The text will be tokenized in the specified level token",
  257. )
  258. parser.add_argument(
  259. "--speech_volume_normalize",
  260. type=float_or_none,
  261. default=None,
  262. help="Scale the maximum amplitude to the given value.",
  263. )
  264. parser.add_argument(
  265. "--rir_scp",
  266. type=str_or_none,
  267. default=None,
  268. help="The file path of rir scp file.",
  269. )
  270. parser.add_argument(
  271. "--rir_apply_prob",
  272. type=float,
  273. default=1.0,
  274. help="THe probability for applying RIR convolution.",
  275. )
  276. parser.add_argument(
  277. "--cmvn_file",
  278. type=str_or_none,
  279. default=None,
  280. help="The file path of noise scp file.",
  281. )
  282. parser.add_argument(
  283. "--noise_scp",
  284. type=str_or_none,
  285. default=None,
  286. help="The file path of noise scp file.",
  287. )
  288. parser.add_argument(
  289. "--noise_apply_prob",
  290. type=float,
  291. default=1.0,
  292. help="The probability applying Noise adding.",
  293. )
  294. parser.add_argument(
  295. "--noise_db_range",
  296. type=str,
  297. default="13_15",
  298. help="The range of noise decibel level.",
  299. )
  300. for class_choices in cls.class_choices_list:
  301. # Append --<name> and --<name>_conf.
  302. # e.g. --encoder and --encoder_conf
  303. class_choices.add_arguments(group)
  304. @classmethod
  305. def build_collate_fn(
  306. cls, args: argparse.Namespace, train: bool
  307. ) -> Callable[
  308. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  309. Tuple[List[str], Dict[str, torch.Tensor]],
  310. ]:
  311. assert check_argument_types()
  312. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  313. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  314. @classmethod
  315. def build_preprocess_fn(
  316. cls, args: argparse.Namespace, train: bool
  317. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  318. assert check_argument_types()
  319. if args.use_preprocessor:
  320. retval = CommonPreprocessor(
  321. train=train,
  322. token_type=args.token_type,
  323. token_list=args.token_list,
  324. bpemodel=None,
  325. non_linguistic_symbols=None,
  326. text_cleaner=None,
  327. g2p_type=None,
  328. split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  329. seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  330. # NOTE(kamo): Check attribute existence for backward compatibility
  331. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  332. rir_apply_prob=args.rir_apply_prob
  333. if hasattr(args, "rir_apply_prob")
  334. else 1.0,
  335. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  336. noise_apply_prob=args.noise_apply_prob
  337. if hasattr(args, "noise_apply_prob")
  338. else 1.0,
  339. noise_db_range=args.noise_db_range
  340. if hasattr(args, "noise_db_range")
  341. else "13_15",
  342. speech_volume_normalize=args.speech_volume_normalize
  343. if hasattr(args, "rir_scp")
  344. else None,
  345. )
  346. else:
  347. retval = None
  348. assert check_return_type(retval)
  349. return retval
  350. @classmethod
  351. def required_data_names(
  352. cls, train: bool = True, inference: bool = False
  353. ) -> Tuple[str, ...]:
  354. if not inference:
  355. retval = ("speech", "profile", "label")
  356. else:
  357. # Recognition mode
  358. retval = ("speech", "profile")
  359. return retval
  360. @classmethod
  361. def optional_data_names(
  362. cls, train: bool = True, inference: bool = False
  363. ) -> Tuple[str, ...]:
  364. retval = ()
  365. assert check_return_type(retval)
  366. return retval
  367. @classmethod
  368. def build_model(cls, args: argparse.Namespace):
  369. assert check_argument_types()
  370. if isinstance(args.token_list, str):
  371. with open(args.token_list, encoding="utf-8") as f:
  372. token_list = [line.rstrip() for line in f]
  373. # Overwriting token_list to keep it as "portable".
  374. args.token_list = list(token_list)
  375. elif isinstance(args.token_list, (tuple, list)):
  376. token_list = list(args.token_list)
  377. else:
  378. raise RuntimeError("token_list must be str or list")
  379. vocab_size = len(token_list)
  380. logging.info(f"Vocabulary size: {vocab_size}")
  381. # 1. frontend
  382. if args.input_size is None:
  383. # Extract features in the model
  384. frontend_class = frontend_choices.get_class(args.frontend)
  385. if args.frontend == 'wav_frontend':
  386. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  387. else:
  388. frontend = frontend_class(**args.frontend_conf)
  389. input_size = frontend.output_size()
  390. else:
  391. # Give features from data-loader
  392. args.frontend = None
  393. args.frontend_conf = {}
  394. frontend = None
  395. input_size = args.input_size
  396. # 2. Data augmentation for spectrogram
  397. if args.specaug is not None:
  398. specaug_class = specaug_choices.get_class(args.specaug)
  399. specaug = specaug_class(**args.specaug_conf)
  400. else:
  401. specaug = None
  402. # 3. Normalization layer
  403. if args.normalize is not None:
  404. normalize_class = normalize_choices.get_class(args.normalize)
  405. normalize = normalize_class(**args.normalize_conf)
  406. else:
  407. normalize = None
  408. # 4. Encoder
  409. encoder_class = encoder_choices.get_class(args.encoder)
  410. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  411. # 5. speaker encoder
  412. if getattr(args, "speaker_encoder", None) is not None:
  413. speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
  414. speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
  415. else:
  416. speaker_encoder = None
  417. # 6. CI & CD scorer
  418. if getattr(args, "ci_scorer", None) is not None:
  419. ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
  420. ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
  421. else:
  422. ci_scorer = None
  423. if getattr(args, "cd_scorer", None) is not None:
  424. cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
  425. cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
  426. else:
  427. cd_scorer = None
  428. # 7. Decoder
  429. decoder_class = decoder_choices.get_class(args.decoder)
  430. decoder = decoder_class(**args.decoder_conf)
  431. if getattr(args, "label_aggregator", None) is not None:
  432. label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
  433. label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
  434. else:
  435. label_aggregator = None
  436. # 9. Build model
  437. model_class = model_choices.get_class(args.model)
  438. model = model_class(
  439. vocab_size=vocab_size,
  440. frontend=frontend,
  441. specaug=specaug,
  442. normalize=normalize,
  443. label_aggregator=label_aggregator,
  444. encoder=encoder,
  445. speaker_encoder=speaker_encoder,
  446. ci_scorer=ci_scorer,
  447. cd_scorer=cd_scorer,
  448. decoder=decoder,
  449. token_list=token_list,
  450. **args.model_conf,
  451. )
  452. # 10. Initialize
  453. if args.init is not None:
  454. initialize(model, args.init)
  455. assert check_return_type(model)
  456. return model
  457. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  458. @classmethod
  459. def build_model_from_file(
  460. cls,
  461. config_file: Union[Path, str] = None,
  462. model_file: Union[Path, str] = None,
  463. cmvn_file: Union[Path, str] = None,
  464. device: str = "cpu",
  465. ):
  466. """Build model from the files.
  467. This method is used for inference or fine-tuning.
  468. Args:
  469. config_file: The yaml file saved when training.
  470. model_file: The model file saved when training.
  471. cmvn_file: The cmvn file for front-end
  472. device: Device type, "cpu", "cuda", or "cuda:N".
  473. """
  474. assert check_argument_types()
  475. if config_file is None:
  476. assert model_file is not None, (
  477. "The argument 'model_file' must be provided "
  478. "if the argument 'config_file' is not specified."
  479. )
  480. config_file = Path(model_file).parent / "config.yaml"
  481. else:
  482. config_file = Path(config_file)
  483. with config_file.open("r", encoding="utf-8") as f:
  484. args = yaml.safe_load(f)
  485. if cmvn_file is not None:
  486. args["cmvn_file"] = cmvn_file
  487. args = argparse.Namespace(**args)
  488. model = cls.build_model(args)
  489. if not isinstance(model, AbsESPnetModel):
  490. raise RuntimeError(
  491. f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
  492. )
  493. model.to(device)
  494. model_dict = dict()
  495. model_name_pth = None
  496. if model_file is not None:
  497. logging.info("model_file is {}".format(model_file))
  498. if device == "cuda":
  499. device = f"cuda:{torch.cuda.current_device()}"
  500. model_dir = os.path.dirname(model_file)
  501. model_name = os.path.basename(model_file)
  502. if "model.ckpt-" in model_name or ".bin" in model_name:
  503. if ".bin" in model_name:
  504. model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
  505. else:
  506. model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
  507. if os.path.exists(model_name_pth):
  508. logging.info("model_file is load from pth: {}".format(model_name_pth))
  509. model_dict = torch.load(model_name_pth, map_location=device)
  510. else:
  511. model_dict = cls.convert_tf2torch(model, model_file)
  512. model.load_state_dict(model_dict)
  513. else:
  514. model_dict = torch.load(model_file, map_location=device)
  515. model.load_state_dict(model_dict)
  516. if model_name_pth is not None and not os.path.exists(model_name_pth):
  517. torch.save(model_dict, model_name_pth)
  518. logging.info("model_file is saved to pth: {}".format(model_name_pth))
  519. return model, args
  520. @classmethod
  521. def convert_tf2torch(
  522. cls,
  523. model,
  524. ckpt,
  525. ):
  526. logging.info("start convert tf model to torch model")
  527. from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
  528. var_dict_tf = load_tf_dict(ckpt)
  529. var_dict_torch = model.state_dict()
  530. var_dict_torch_update = dict()
  531. # speech encoder
  532. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  533. var_dict_torch_update.update(var_dict_torch_update_local)
  534. # speaker encoder
  535. var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  536. var_dict_torch_update.update(var_dict_torch_update_local)
  537. # cd scorer
  538. var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  539. var_dict_torch_update.update(var_dict_torch_update_local)
  540. # ci scorer
  541. var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  542. var_dict_torch_update.update(var_dict_torch_update_local)
  543. # decoder
  544. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  545. var_dict_torch_update.update(var_dict_torch_update_local)
  546. return var_dict_torch_update