diar.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659
  1. import argparse
  2. import logging
  3. import os
  4. from pathlib import Path
  5. from typing import Callable
  6. from typing import Collection
  7. from typing import Dict
  8. from typing import List
  9. from typing import Optional
  10. from typing import Tuple
  11. from typing import Union
  12. import numpy as np
  13. import torch
  14. import yaml
  15. from funasr.datasets.collate_fn import DiarCollateFn
  16. from funasr.datasets.preprocessor import CommonPreprocessor
  17. from funasr.layers.abs_normalize import AbsNormalize
  18. from funasr.layers.global_mvn import GlobalMVN
  19. from funasr.layers.utterance_mvn import UtteranceMVN
  20. from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
  21. from funasr.models.ctc import CTC
  22. from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
  23. from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
  24. from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
  25. from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
  26. from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
  27. from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
  28. from funasr.models.e2e_diar_sond import DiarSondModel
  29. from funasr.models.encoder.abs_encoder import AbsEncoder
  30. from funasr.models.encoder.conformer_encoder import ConformerEncoder
  31. from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
  32. from funasr.models.encoder.rnn_encoder import RNNEncoder
  33. from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
  34. from funasr.models.encoder.transformer_encoder import TransformerEncoder
  35. from funasr.models.frontend.abs_frontend import AbsFrontend
  36. from funasr.models.frontend.default import DefaultFrontend
  37. from funasr.models.frontend.fused import FusedFrontends
  38. from funasr.models.frontend.s3prl import S3prlFrontend
  39. from funasr.models.frontend.wav_frontend import WavFrontend
  40. from funasr.models.frontend.windowing import SlidingWindow
  41. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  42. from funasr.models.postencoder.hugging_face_transformers_postencoder import (
  43. HuggingFaceTransformersPostEncoder, # noqa: H301
  44. )
  45. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  46. from funasr.models.preencoder.linear import LinearProjection
  47. from funasr.models.preencoder.sinc import LightweightSincConvs
  48. from funasr.models.specaug.abs_specaug import AbsSpecAug
  49. from funasr.models.specaug.specaug import SpecAug
  50. from funasr.models.specaug.specaug import SpecAugLFR
  51. from funasr.models.specaug.abs_profileaug import AbsProfileAug
  52. from funasr.models.specaug.profileaug import ProfileAug
  53. from funasr.tasks.abs_task import AbsTask
  54. from funasr.torch_utils.initialize import initialize
  55. from funasr.train.class_choices import ClassChoices
  56. from funasr.train.trainer import Trainer
  57. from funasr.utils.types import float_or_none
  58. from funasr.utils.types import int_or_none
  59. from funasr.utils.types import str2bool
  60. from funasr.utils.types import str_or_none
  61. frontend_choices = ClassChoices(
  62. name="frontend",
  63. classes=dict(
  64. default=DefaultFrontend,
  65. sliding_window=SlidingWindow,
  66. s3prl=S3prlFrontend,
  67. fused=FusedFrontends,
  68. wav_frontend=WavFrontend,
  69. ),
  70. type_check=AbsFrontend,
  71. default="default",
  72. )
  73. specaug_choices = ClassChoices(
  74. name="specaug",
  75. classes=dict(
  76. specaug=SpecAug,
  77. specaug_lfr=SpecAugLFR,
  78. ),
  79. type_check=AbsSpecAug,
  80. default=None,
  81. optional=True,
  82. )
  83. profileaug_choices = ClassChoices(
  84. name="profileaug",
  85. classes=dict(
  86. profileaug=ProfileAug,
  87. ),
  88. type_check=AbsProfileAug,
  89. default=None,
  90. optional=True,
  91. )
  92. normalize_choices = ClassChoices(
  93. "normalize",
  94. classes=dict(
  95. global_mvn=GlobalMVN,
  96. utterance_mvn=UtteranceMVN,
  97. ),
  98. type_check=AbsNormalize,
  99. default=None,
  100. optional=True,
  101. )
  102. label_aggregator_choices = ClassChoices(
  103. "label_aggregator",
  104. classes=dict(
  105. label_aggregator=LabelAggregate,
  106. label_aggregator_max_pool=LabelAggregateMaxPooling,
  107. ),
  108. type_check=torch.nn.Module,
  109. default=None,
  110. optional=True,
  111. )
  112. model_choices = ClassChoices(
  113. "model",
  114. classes=dict(
  115. sond=DiarSondModel,
  116. ),
  117. type_check=torch.nn.Module,
  118. default="sond",
  119. )
  120. encoder_choices = ClassChoices(
  121. "encoder",
  122. classes=dict(
  123. conformer=ConformerEncoder,
  124. transformer=TransformerEncoder,
  125. rnn=RNNEncoder,
  126. sanm=SANMEncoder,
  127. san=SelfAttentionEncoder,
  128. fsmn=FsmnEncoder,
  129. conv=ConvEncoder,
  130. resnet34=ResNet34Diar,
  131. resnet34_sp_l2reg=ResNet34SpL2RegDiar,
  132. sanm_chunk_opt=SANMEncoderChunkOpt,
  133. data2vec_encoder=Data2VecEncoder,
  134. ecapa_tdnn=ECAPA_TDNN,
  135. ),
  136. type_check=torch.nn.Module,
  137. default="resnet34",
  138. )
  139. speaker_encoder_choices = ClassChoices(
  140. "speaker_encoder",
  141. classes=dict(
  142. conformer=ConformerEncoder,
  143. transformer=TransformerEncoder,
  144. rnn=RNNEncoder,
  145. sanm=SANMEncoder,
  146. san=SelfAttentionEncoder,
  147. fsmn=FsmnEncoder,
  148. conv=ConvEncoder,
  149. sanm_chunk_opt=SANMEncoderChunkOpt,
  150. data2vec_encoder=Data2VecEncoder,
  151. ),
  152. type_check=AbsEncoder,
  153. default=None,
  154. optional=True
  155. )
  156. cd_scorer_choices = ClassChoices(
  157. "cd_scorer",
  158. classes=dict(
  159. san=SelfAttentionEncoder,
  160. ),
  161. type_check=AbsEncoder,
  162. default=None,
  163. optional=True,
  164. )
  165. ci_scorer_choices = ClassChoices(
  166. "ci_scorer",
  167. classes=dict(
  168. dot=DotScorer,
  169. cosine=CosScorer,
  170. conv=ConvEncoder,
  171. ),
  172. type_check=torch.nn.Module,
  173. default=None,
  174. optional=True,
  175. )
  176. # decoder is used for output (e.g. post_net in SOND)
  177. decoder_choices = ClassChoices(
  178. "decoder",
  179. classes=dict(
  180. rnn=RNNEncoder,
  181. fsmn=FsmnEncoder,
  182. ),
  183. type_check=torch.nn.Module,
  184. default="fsmn",
  185. )
  186. class DiarTask(AbsTask):
  187. # If you need more than 1 optimizer, change this value
  188. num_optimizers: int = 1
  189. # Add variable objects configurations
  190. class_choices_list = [
  191. # --frontend and --frontend_conf
  192. frontend_choices,
  193. # --specaug and --specaug_conf
  194. specaug_choices,
  195. # --profileaug and --profileaug_conf
  196. profileaug_choices,
  197. # --normalize and --normalize_conf
  198. normalize_choices,
  199. # --label_aggregator and --label_aggregator_conf
  200. label_aggregator_choices,
  201. # --model and --model_conf
  202. model_choices,
  203. # --encoder and --encoder_conf
  204. encoder_choices,
  205. # --speaker_encoder and --speaker_encoder_conf
  206. speaker_encoder_choices,
  207. # --cd_scorer and cd_scorer_conf
  208. cd_scorer_choices,
  209. # --ci_scorer and ci_scorer_conf
  210. ci_scorer_choices,
  211. # --decoder and --decoder_conf
  212. decoder_choices,
  213. ]
  214. # If you need to modify train() or eval() procedures, change Trainer class here
  215. trainer = Trainer
  216. @classmethod
  217. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  218. group = parser.add_argument_group(description="Task related")
  219. # NOTE(kamo): add_arguments(..., required=True) can't be used
  220. # to provide --print_config mode. Instead of it, do as
  221. # required = parser.get_default("required")
  222. # required += ["token_list"]
  223. group.add_argument(
  224. "--token_list",
  225. type=str_or_none,
  226. default=None,
  227. help="A text mapping int-id to token",
  228. )
  229. group.add_argument(
  230. "--split_with_space",
  231. type=str2bool,
  232. default=True,
  233. help="whether to split text using <space>",
  234. )
  235. group.add_argument(
  236. "--seg_dict_file",
  237. type=str,
  238. default=None,
  239. help="seg_dict_file for text processing",
  240. )
  241. group.add_argument(
  242. "--init",
  243. type=lambda x: str_or_none(x.lower()),
  244. default=None,
  245. help="The initialization method",
  246. choices=[
  247. "chainer",
  248. "xavier_uniform",
  249. "xavier_normal",
  250. "kaiming_uniform",
  251. "kaiming_normal",
  252. None,
  253. ],
  254. )
  255. group.add_argument(
  256. "--input_size",
  257. type=int_or_none,
  258. default=None,
  259. help="The number of input dimension of the feature",
  260. )
  261. group = parser.add_argument_group(description="Preprocess related")
  262. group.add_argument(
  263. "--use_preprocessor",
  264. type=str2bool,
  265. default=True,
  266. help="Apply preprocessing to data or not",
  267. )
  268. group.add_argument(
  269. "--token_type",
  270. type=str,
  271. default="char",
  272. choices=["char"],
  273. help="The text will be tokenized in the specified level token",
  274. )
  275. parser.add_argument(
  276. "--speech_volume_normalize",
  277. type=float_or_none,
  278. default=None,
  279. help="Scale the maximum amplitude to the given value.",
  280. )
  281. parser.add_argument(
  282. "--rir_scp",
  283. type=str_or_none,
  284. default=None,
  285. help="The file path of rir scp file.",
  286. )
  287. parser.add_argument(
  288. "--rir_apply_prob",
  289. type=float,
  290. default=1.0,
  291. help="THe probability for applying RIR convolution.",
  292. )
  293. parser.add_argument(
  294. "--cmvn_file",
  295. type=str_or_none,
  296. default=None,
  297. help="The file path of noise scp file.",
  298. )
  299. parser.add_argument(
  300. "--noise_scp",
  301. type=str_or_none,
  302. default=None,
  303. help="The file path of noise scp file.",
  304. )
  305. parser.add_argument(
  306. "--noise_apply_prob",
  307. type=float,
  308. default=1.0,
  309. help="The probability applying Noise adding.",
  310. )
  311. parser.add_argument(
  312. "--noise_db_range",
  313. type=str,
  314. default="13_15",
  315. help="The range of noise decibel level.",
  316. )
  317. for class_choices in cls.class_choices_list:
  318. # Append --<name> and --<name>_conf.
  319. # e.g. --encoder and --encoder_conf
  320. class_choices.add_arguments(group)
  321. @classmethod
  322. def build_collate_fn(
  323. cls, args: argparse.Namespace, train: bool
  324. ) -> Callable[
  325. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  326. Tuple[List[str], Dict[str, torch.Tensor]],
  327. ]:
  328. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  329. return DiarCollateFn(float_pad_value=0.0, int_pad_value=-1)
  330. @classmethod
  331. def build_preprocess_fn(
  332. cls, args: argparse.Namespace, train: bool
  333. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  334. if args.use_preprocessor:
  335. retval = CommonPreprocessor(
  336. train=train,
  337. token_type=args.token_type,
  338. token_list=args.token_list,
  339. bpemodel=None,
  340. non_linguistic_symbols=None,
  341. text_cleaner=None,
  342. g2p_type=None,
  343. split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
  344. seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
  345. # NOTE(kamo): Check attribute existence for backward compatibility
  346. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  347. rir_apply_prob=args.rir_apply_prob
  348. if hasattr(args, "rir_apply_prob")
  349. else 1.0,
  350. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  351. noise_apply_prob=args.noise_apply_prob
  352. if hasattr(args, "noise_apply_prob")
  353. else 1.0,
  354. noise_db_range=args.noise_db_range
  355. if hasattr(args, "noise_db_range")
  356. else "13_15",
  357. speech_volume_normalize=args.speech_volume_normalize
  358. if hasattr(args, "rir_scp")
  359. else None,
  360. )
  361. else:
  362. retval = None
  363. return retval
  364. @classmethod
  365. def required_data_names(
  366. cls, train: bool = True, inference: bool = False
  367. ) -> Tuple[str, ...]:
  368. if not inference:
  369. retval = ("speech", "profile", "binary_labels")
  370. else:
  371. # Recognition mode
  372. retval = ("speech", "profile")
  373. return retval
  374. @classmethod
  375. def optional_data_names(
  376. cls, train: bool = True, inference: bool = False
  377. ) -> Tuple[str, ...]:
  378. retval = ()
  379. return retval
  380. @classmethod
  381. def build_optimizers(
  382. cls,
  383. args: argparse.Namespace,
  384. model: torch.nn.Module,
  385. ) -> List[torch.optim.Optimizer]:
  386. if cls.num_optimizers != 1:
  387. raise RuntimeError(
  388. "build_optimizers() must be overridden if num_optimizers != 1"
  389. )
  390. from funasr.tasks.abs_task import optim_classes
  391. optim_class = optim_classes.get(args.optim)
  392. if optim_class is None:
  393. raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
  394. else:
  395. if (hasattr(model, "model_regularizer_weight") and
  396. model.model_regularizer_weight > 0.0 and
  397. hasattr(model, "get_regularize_parameters")
  398. ):
  399. to_regularize_parameters, normal_parameters = model.get_regularize_parameters()
  400. logging.info(f"Set weight decay {model.model_regularizer_weight} for parameters: "
  401. f"{[name for name, value in to_regularize_parameters]}")
  402. module_optim_config = [
  403. {"params": [value for name, value in to_regularize_parameters],
  404. "weight_decay": model.model_regularizer_weight},
  405. {"params": [value for name, value in normal_parameters],
  406. "weight_decay": 0.0}
  407. ]
  408. optim = optim_class(module_optim_config, **args.optim_conf)
  409. else:
  410. optim = optim_class(model.parameters(), **args.optim_conf)
  411. optimizers = [optim]
  412. return optimizers
  413. @classmethod
  414. def build_model(cls, args: argparse.Namespace):
  415. if isinstance(args.token_list, str):
  416. with open(args.token_list, encoding="utf-8") as f:
  417. token_list = [line.rstrip() for line in f]
  418. # Overwriting token_list to keep it as "portable".
  419. args.token_list = list(token_list)
  420. elif isinstance(args.token_list, (tuple, list)):
  421. token_list = list(args.token_list)
  422. else:
  423. raise RuntimeError("token_list must be str or list")
  424. vocab_size = len(token_list)
  425. logging.info(f"Vocabulary size: {vocab_size}")
  426. # 1. frontend
  427. if args.input_size is None:
  428. # Extract features in the model
  429. frontend_class = frontend_choices.get_class(args.frontend)
  430. if args.frontend == 'wav_frontend':
  431. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  432. else:
  433. frontend = frontend_class(**args.frontend_conf)
  434. input_size = frontend.output_size()
  435. else:
  436. # Give features from data-loader
  437. args.frontend = None
  438. args.frontend_conf = {}
  439. frontend = None
  440. input_size = args.input_size
  441. # 2. Data augmentation for spectrogram
  442. if args.specaug is not None:
  443. specaug_class = specaug_choices.get_class(args.specaug)
  444. specaug = specaug_class(**args.specaug_conf)
  445. else:
  446. specaug = None
  447. # 2b. Data augmentation for Profiles
  448. if hasattr(args, "profileaug") and args.profileaug is not None:
  449. profileaug_class = profileaug_choices.get_class(args.profileaug)
  450. profileaug = profileaug_class(**args.profileaug_conf)
  451. else:
  452. profileaug = None
  453. # 3. Normalization layer
  454. if args.normalize is not None:
  455. normalize_class = normalize_choices.get_class(args.normalize)
  456. normalize = normalize_class(**args.normalize_conf)
  457. else:
  458. normalize = None
  459. # 4. Encoder
  460. encoder_class = encoder_choices.get_class(args.encoder)
  461. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  462. # 5. speaker encoder
  463. if getattr(args, "speaker_encoder", None) is not None:
  464. speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
  465. speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
  466. else:
  467. speaker_encoder = None
  468. # 6. CI & CD scorer
  469. if getattr(args, "ci_scorer", None) is not None:
  470. ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
  471. ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
  472. else:
  473. ci_scorer = None
  474. if getattr(args, "cd_scorer", None) is not None:
  475. cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
  476. cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
  477. else:
  478. cd_scorer = None
  479. # 7. Decoder
  480. decoder_class = decoder_choices.get_class(args.decoder)
  481. decoder = decoder_class(**args.decoder_conf)
  482. if getattr(args, "label_aggregator", None) is not None:
  483. label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
  484. label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
  485. else:
  486. label_aggregator = None
  487. # 9. Build model
  488. model_class = model_choices.get_class(args.model)
  489. model = model_class(
  490. vocab_size=vocab_size,
  491. frontend=frontend,
  492. specaug=specaug,
  493. profileaug=profileaug,
  494. normalize=normalize,
  495. label_aggregator=label_aggregator,
  496. encoder=encoder,
  497. speaker_encoder=speaker_encoder,
  498. ci_scorer=ci_scorer,
  499. cd_scorer=cd_scorer,
  500. decoder=decoder,
  501. token_list=token_list,
  502. **args.model_conf,
  503. )
  504. # 10. Initialize
  505. if args.init is not None:
  506. initialize(model, args.init)
  507. logging.info(f"Init model parameters with {args.init}.")
  508. return model
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: Union[str, torch.device] = "cpu",
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            cmvn_file: The cmvn file for front-end
            device: Device type, "cpu", "cuda", or "cuda:N".

        Returns:
            Tuple of (model moved to ``device``, the argparse.Namespace built
            from the yaml config).
        """
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            # Fall back to the config.yaml that training saves next to the model.
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            # Override the cmvn path stored in the training config.
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, torch.nn.Module):
            raise RuntimeError(
                f"model must inherit {torch.nn.Module.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
        # Path of the cached torch-format checkpoint; stays None when the
        # input is already a native torch checkpoint.
        model_name_pth = None
        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                # Pin to the current CUDA device for torch.load map_location.
                device = f"cuda:{torch.cuda.current_device()}"
            model_dir = os.path.dirname(model_file)
            model_name = os.path.basename(model_file)
            # "model.ckpt-*" / "*.bin" indicate a TensorFlow checkpoint that
            # must be converted; the converted dict is cached beside it.
            if "model.ckpt-" in model_name or ".bin" in model_name:
                if ".bin" in model_name:
                    model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
                else:
                    model_name_pth = os.path.join(model_dir, "{}.pth".format(model_name))
                if os.path.exists(model_name_pth):
                    # Reuse a previously converted-and-cached checkpoint.
                    logging.info("model_file is load from pth: {}".format(model_name_pth))
                    model_dict = torch.load(model_name_pth, map_location=device)
                else:
                    model_dict = cls.convert_tf2torch(model, model_file)
                # model.load_state_dict(model_dict)
            else:
                model_dict = torch.load(model_file, map_location=device)
            # Drop keys the current model doesn't have (and warn on missing ones).
            model_dict = cls.fileter_model_dict(model_dict, model.state_dict())
            model.load_state_dict(model_dict)
            # Cache the converted dict so the TF->torch conversion runs once.
            if model_name_pth is not None and not os.path.exists(model_name_pth):
                torch.save(model_dict, model_name_pth)
                logging.info("model_file is saved to pth: {}".format(model_name_pth))
        return model, args
  572. @classmethod
  573. def fileter_model_dict(cls, src_dict: dict, dest_dict: dict):
  574. from collections import OrderedDict
  575. new_dict = OrderedDict()
  576. for key, value in src_dict.items():
  577. if key in dest_dict:
  578. new_dict[key] = value
  579. else:
  580. logging.info("{} is no longer needed in this model.".format(key))
  581. for key, value in dest_dict.items():
  582. if key not in new_dict:
  583. logging.warning("{} is missed in checkpoint.".format(key))
  584. return new_dict
  585. @classmethod
  586. def convert_tf2torch(
  587. cls,
  588. model,
  589. ckpt,
  590. ):
  591. logging.info("start convert tf model to torch model")
  592. from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
  593. var_dict_tf = load_tf_dict(ckpt)
  594. var_dict_torch = model.state_dict()
  595. var_dict_torch_update = dict()
  596. # speech encoder
  597. if model.encoder is not None:
  598. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  599. var_dict_torch_update.update(var_dict_torch_update_local)
  600. # speaker encoder
  601. if model.speaker_encoder is not None:
  602. var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  603. var_dict_torch_update.update(var_dict_torch_update_local)
  604. # cd scorer
  605. if model.cd_scorer is not None:
  606. var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  607. var_dict_torch_update.update(var_dict_torch_update_local)
  608. # ci scorer
  609. if model.ci_scorer is not None:
  610. var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
  611. var_dict_torch_update.update(var_dict_torch_update_local)
  612. # decoder
  613. if model.decoder is not None:
  614. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  615. var_dict_torch_update.update(var_dict_torch_update_local)
  616. return var_dict_torch_update