# sv.py — Speaker verification (SV) task definition for FunASR.
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.base_model import FunASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.sv_decoder import DenseDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_sv import ESPnetSVModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.pooling.statistic_pooling import StatisticPooling
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
  54. frontend_choices = ClassChoices(
  55. name="frontend",
  56. classes=dict(
  57. default=DefaultFrontend,
  58. sliding_window=SlidingWindow,
  59. s3prl=S3prlFrontend,
  60. fused=FusedFrontends,
  61. wav_frontend=WavFrontend,
  62. ),
  63. type_check=AbsFrontend,
  64. default="default",
  65. )
  66. specaug_choices = ClassChoices(
  67. name="specaug",
  68. classes=dict(
  69. specaug=SpecAug,
  70. ),
  71. type_check=AbsSpecAug,
  72. default=None,
  73. optional=True,
  74. )
  75. normalize_choices = ClassChoices(
  76. "normalize",
  77. classes=dict(
  78. global_mvn=GlobalMVN,
  79. utterance_mvn=UtteranceMVN,
  80. ),
  81. type_check=AbsNormalize,
  82. default=None,
  83. optional=True,
  84. )
  85. model_choices = ClassChoices(
  86. "model",
  87. classes=dict(
  88. espnet=ESPnetSVModel,
  89. ),
  90. type_check=FunASRModel,
  91. default="espnet",
  92. )
  93. preencoder_choices = ClassChoices(
  94. name="preencoder",
  95. classes=dict(
  96. sinc=LightweightSincConvs,
  97. linear=LinearProjection,
  98. ),
  99. type_check=AbsPreEncoder,
  100. default=None,
  101. optional=True,
  102. )
  103. encoder_choices = ClassChoices(
  104. "encoder",
  105. classes=dict(
  106. resnet34=ResNet34,
  107. resnet34_sp_l2reg=ResNet34_SP_L2Reg,
  108. rnn=RNNEncoder,
  109. ),
  110. type_check=AbsEncoder,
  111. default="resnet34",
  112. )
  113. postencoder_choices = ClassChoices(
  114. name="postencoder",
  115. classes=dict(
  116. hugging_face_transformers=HuggingFaceTransformersPostEncoder,
  117. ),
  118. type_check=AbsPostEncoder,
  119. default=None,
  120. optional=True,
  121. )
  122. pooling_choices = ClassChoices(
  123. name="pooling_type",
  124. classes=dict(
  125. statistic=StatisticPooling,
  126. ),
  127. type_check=torch.nn.Module,
  128. default="statistic",
  129. )
  130. decoder_choices = ClassChoices(
  131. "decoder",
  132. classes=dict(
  133. dense=DenseDecoder,
  134. ),
  135. type_check=AbsDecoder,
  136. default="dense",
  137. )
  138. class SVTask(AbsTask):
  139. # If you need more than one optimizers, change this value
  140. num_optimizers: int = 1
  141. # Add variable objects configurations
  142. class_choices_list = [
  143. # --frontend and --frontend_conf
  144. frontend_choices,
  145. # --specaug and --specaug_conf
  146. specaug_choices,
  147. # --normalize and --normalize_conf
  148. normalize_choices,
  149. # --model and --model_conf
  150. model_choices,
  151. # --preencoder and --preencoder_conf
  152. preencoder_choices,
  153. # --encoder and --encoder_conf
  154. encoder_choices,
  155. # --postencoder and --postencoder_conf
  156. postencoder_choices,
  157. # --pooling and --pooling_conf
  158. pooling_choices,
  159. # --decoder and --decoder_conf
  160. decoder_choices,
  161. ]
  162. # If you need to modify train() or eval() procedures, change Trainer class here
  163. trainer = Trainer
  164. @classmethod
  165. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  166. group = parser.add_argument_group(description="Task related")
  167. # NOTE(kamo): add_arguments(..., required=True) can't be used
  168. # to provide --print_config mode. Instead of it, do as
  169. required = parser.get_default("required")
  170. required += ["token_list"]
  171. group.add_argument(
  172. "--token_list",
  173. type=str_or_none,
  174. default=None,
  175. help="A text mapping int-id to speaker name",
  176. )
  177. group.add_argument(
  178. "--init",
  179. type=lambda x: str_or_none(x.lower()),
  180. default=None,
  181. help="The initialization method",
  182. choices=[
  183. "chainer",
  184. "xavier_uniform",
  185. "xavier_normal",
  186. "kaiming_uniform",
  187. "kaiming_normal",
  188. None,
  189. ],
  190. )
  191. group.add_argument(
  192. "--input_size",
  193. type=int_or_none,
  194. default=None,
  195. help="The number of input dimension of the feature",
  196. )
  197. group = parser.add_argument_group(description="Preprocess related")
  198. group.add_argument(
  199. "--use_preprocessor",
  200. type=str2bool,
  201. default=True,
  202. help="Apply preprocessing to data or not",
  203. )
  204. parser.add_argument(
  205. "--cleaner",
  206. type=str_or_none,
  207. choices=[None, "tacotron", "jaconv", "vietnamese"],
  208. default=None,
  209. help="Apply text cleaning",
  210. )
  211. parser.add_argument(
  212. "--speech_volume_normalize",
  213. type=float_or_none,
  214. default=None,
  215. help="Scale the maximum amplitude to the given value.",
  216. )
  217. parser.add_argument(
  218. "--rir_scp",
  219. type=str_or_none,
  220. default=None,
  221. help="The file path of rir scp file.",
  222. )
  223. parser.add_argument(
  224. "--rir_apply_prob",
  225. type=float,
  226. default=1.0,
  227. help="THe probability for applying RIR convolution.",
  228. )
  229. parser.add_argument(
  230. "--noise_scp",
  231. type=str_or_none,
  232. default=None,
  233. help="The file path of noise scp file.",
  234. )
  235. parser.add_argument(
  236. "--noise_apply_prob",
  237. type=float,
  238. default=1.0,
  239. help="The probability applying Noise adding.",
  240. )
  241. parser.add_argument(
  242. "--noise_db_range",
  243. type=str,
  244. default="13_15",
  245. help="The range of noise decibel level.",
  246. )
  247. for class_choices in cls.class_choices_list:
  248. # Append --<name> and --<name>_conf.
  249. # e.g. --encoder and --encoder_conf
  250. class_choices.add_arguments(group)
  251. @classmethod
  252. def build_collate_fn(
  253. cls, args: argparse.Namespace, train: bool
  254. ) -> Callable[
  255. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  256. Tuple[List[str], Dict[str, torch.Tensor]],
  257. ]:
  258. assert check_argument_types()
  259. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  260. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  261. @classmethod
  262. def build_preprocess_fn(
  263. cls, args: argparse.Namespace, train: bool
  264. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  265. assert check_argument_types()
  266. if args.use_preprocessor:
  267. retval = CommonPreprocessor(
  268. train=train,
  269. token_type=None,
  270. token_list=None,
  271. bpemodel=None,
  272. non_linguistic_symbols=None,
  273. text_cleaner=args.cleaner,
  274. g2p_type=None,
  275. # NOTE(kamo): Check attribute existence for backward compatibility
  276. rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  277. rir_apply_prob=args.rir_apply_prob
  278. if hasattr(args, "rir_apply_prob")
  279. else 1.0,
  280. noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  281. noise_apply_prob=args.noise_apply_prob
  282. if hasattr(args, "noise_apply_prob")
  283. else 1.0,
  284. noise_db_range=args.noise_db_range
  285. if hasattr(args, "noise_db_range")
  286. else "13_15",
  287. speech_volume_normalize=args.speech_volume_normalize
  288. if hasattr(args, "rir_scp")
  289. else None,
  290. )
  291. else:
  292. retval = None
  293. assert check_return_type(retval)
  294. return retval
  295. @classmethod
  296. def required_data_names(
  297. cls, train: bool = True, inference: bool = False
  298. ) -> Tuple[str, ...]:
  299. if not inference:
  300. retval = ("speech", "text")
  301. else:
  302. # Recognition mode
  303. retval = ("speech",)
  304. return retval
  305. @classmethod
  306. def optional_data_names(
  307. cls, train: bool = True, inference: bool = False
  308. ) -> Tuple[str, ...]:
  309. retval = ()
  310. if inference:
  311. retval = ("ref_speech",)
  312. assert check_return_type(retval)
  313. return retval
  314. @classmethod
  315. def build_model(cls, args: argparse.Namespace) -> ESPnetSVModel:
  316. assert check_argument_types()
  317. if isinstance(args.token_list, str):
  318. with open(args.token_list, encoding="utf-8") as f:
  319. token_list = [line.rstrip() for line in f]
  320. # Overwriting token_list to keep it as "portable".
  321. args.token_list = list(token_list)
  322. elif isinstance(args.token_list, (tuple, list)):
  323. token_list = list(args.token_list)
  324. else:
  325. raise RuntimeError("token_list must be str or list")
  326. vocab_size = len(token_list)
  327. logging.info(f"Speaker number: {vocab_size}")
  328. # 1. frontend
  329. if args.input_size is None:
  330. # Extract features in the model
  331. frontend_class = frontend_choices.get_class(args.frontend)
  332. frontend = frontend_class(**args.frontend_conf)
  333. input_size = frontend.output_size()
  334. else:
  335. # Give features from data-loader
  336. args.frontend = None
  337. args.frontend_conf = {}
  338. frontend = None
  339. input_size = args.input_size
  340. # 2. Data augmentation for spectrogram
  341. if args.specaug is not None:
  342. specaug_class = specaug_choices.get_class(args.specaug)
  343. specaug = specaug_class(**args.specaug_conf)
  344. else:
  345. specaug = None
  346. # 3. Normalization layer
  347. if args.normalize is not None:
  348. normalize_class = normalize_choices.get_class(args.normalize)
  349. normalize = normalize_class(**args.normalize_conf)
  350. else:
  351. normalize = None
  352. # 4. Pre-encoder input block
  353. # NOTE(kan-bayashi): Use getattr to keep the compatibility
  354. if getattr(args, "preencoder", None) is not None:
  355. preencoder_class = preencoder_choices.get_class(args.preencoder)
  356. preencoder = preencoder_class(**args.preencoder_conf)
  357. input_size = preencoder.output_size()
  358. else:
  359. preencoder = None
  360. # 5. Encoder
  361. encoder_class = encoder_choices.get_class(args.encoder)
  362. encoder = encoder_class(input_size=input_size, **args.encoder_conf)
  363. # 6. Post-encoder block
  364. # NOTE(kan-bayashi): Use getattr to keep the compatibility
  365. encoder_output_size = encoder.output_size()
  366. if getattr(args, "postencoder", None) is not None:
  367. postencoder_class = postencoder_choices.get_class(args.postencoder)
  368. postencoder = postencoder_class(
  369. input_size=encoder_output_size, **args.postencoder_conf
  370. )
  371. encoder_output_size = postencoder.output_size()
  372. else:
  373. postencoder = None
  374. # 7. Pooling layer
  375. pooling_class = pooling_choices.get_class(args.pooling_type)
  376. pooling_dim = (2, 3)
  377. eps = 1e-12
  378. if hasattr(args, "pooling_type_conf"):
  379. if "pooling_dim" in args.pooling_type_conf:
  380. pooling_dim = args.pooling_type_conf["pooling_dim"]
  381. if "eps" in args.pooling_type_conf:
  382. eps = args.pooling_type_conf["eps"]
  383. pooling_layer = pooling_class(
  384. pooling_dim=pooling_dim,
  385. eps=eps,
  386. )
  387. if args.pooling_type == "statistic":
  388. encoder_output_size *= 2
  389. # 8. Decoder
  390. decoder_class = decoder_choices.get_class(args.decoder)
  391. decoder = decoder_class(
  392. vocab_size=vocab_size,
  393. encoder_output_size=encoder_output_size,
  394. **args.decoder_conf,
  395. )
  396. # 7. Build model
  397. try:
  398. model_class = model_choices.get_class(args.model)
  399. except AttributeError:
  400. model_class = model_choices.get_class("espnet")
  401. model = model_class(
  402. vocab_size=vocab_size,
  403. token_list=token_list,
  404. frontend=frontend,
  405. specaug=specaug,
  406. normalize=normalize,
  407. preencoder=preencoder,
  408. encoder=encoder,
  409. postencoder=postencoder,
  410. pooling_layer=pooling_layer,
  411. decoder=decoder,
  412. **args.model_conf,
  413. )
  414. # FIXME(kamo): Should be done in model?
  415. # 8. Initialize
  416. if args.init is not None:
  417. initialize(model, args.init)
  418. assert check_return_type(model)
  419. return model
  420. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  421. @classmethod
  422. def build_model_from_file(
  423. cls,
  424. config_file: Union[Path, str] = None,
  425. model_file: Union[Path, str] = None,
  426. cmvn_file: Union[Path, str] = None,
  427. device: str = "cpu",
  428. ):
  429. """Build model from the files.
  430. This method is used for inference or fine-tuning.
  431. Args:
  432. config_file: The yaml file saved when training.
  433. model_file: The model file saved when training.
  434. cmvn_file: The cmvn file for front-end
  435. device: Device type, "cpu", "cuda", or "cuda:N".
  436. """
  437. assert check_argument_types()
  438. if config_file is None:
  439. assert model_file is not None, (
  440. "The argument 'model_file' must be provided "
  441. "if the argument 'config_file' is not specified."
  442. )
  443. config_file = Path(model_file).parent / "config.yaml"
  444. else:
  445. config_file = Path(config_file)
  446. with config_file.open("r", encoding="utf-8") as f:
  447. args = yaml.safe_load(f)
  448. if cmvn_file is not None:
  449. args["cmvn_file"] = cmvn_file
  450. args = argparse.Namespace(**args)
  451. model = cls.build_model(args)
  452. if not isinstance(model, FunASRModel):
  453. raise RuntimeError(
  454. f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
  455. )
  456. model.to(device)
  457. model_dict = dict()
  458. model_name_pth = None
  459. if model_file is not None:
  460. logging.info("model_file is {}".format(model_file))
  461. if device == "cuda":
  462. device = f"cuda:{torch.cuda.current_device()}"
  463. model_dir = os.path.dirname(model_file)
  464. model_name = os.path.basename(model_file)
  465. if "model.ckpt-" in model_name or ".bin" in model_name:
  466. if ".bin" in model_name:
  467. model_name_pth = os.path.join(model_dir, model_name.replace('.bin', '.pb'))
  468. else:
  469. model_name_pth = os.path.join(model_dir, "{}.pb".format(model_name))
  470. if os.path.exists(model_name_pth):
  471. logging.info("model_file is load from pth: {}".format(model_name_pth))
  472. model_dict = torch.load(model_name_pth, map_location=device)
  473. else:
  474. model_dict = cls.convert_tf2torch(model, model_file)
  475. model.load_state_dict(model_dict)
  476. else:
  477. model_dict = torch.load(model_file, map_location=device)
  478. model.load_state_dict(model_dict)
  479. if model_name_pth is not None and not os.path.exists(model_name_pth):
  480. torch.save(model_dict, model_name_pth)
  481. logging.info("model_file is saved to pth: {}".format(model_name_pth))
  482. return model, args
  483. @classmethod
  484. def convert_tf2torch(
  485. cls,
  486. model,
  487. ckpt,
  488. ):
  489. logging.info("start convert tf model to torch model")
  490. from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
  491. var_dict_tf = load_tf_dict(ckpt)
  492. var_dict_torch = model.state_dict()
  493. var_dict_torch_update = dict()
  494. # speech encoder
  495. var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  496. var_dict_torch_update.update(var_dict_torch_update_local)
  497. # pooling layer
  498. var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
  499. var_dict_torch_update.update(var_dict_torch_update_local)
  500. # decoder
  501. var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
  502. var_dict_torch_update.update(var_dict_torch_update_local)
  503. return var_dict_torch_update