# vad.py — VAD (voice activity detection) task definition.
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_vad import E2EVadModel
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.tasks.abs_task import AbsTask
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str_or_none
  38. frontend_choices = ClassChoices(
  39. name="frontend",
  40. classes=dict(
  41. default=DefaultFrontend,
  42. sliding_window=SlidingWindow,
  43. s3prl=S3prlFrontend,
  44. fused=FusedFrontends,
  45. wav_frontend=WavFrontend,
  46. wav_frontend_online=WavFrontendOnline,
  47. ),
  48. type_check=AbsFrontend,
  49. default="default",
  50. )
  51. specaug_choices = ClassChoices(
  52. name="specaug",
  53. classes=dict(
  54. specaug=SpecAug,
  55. specaug_lfr=SpecAugLFR,
  56. ),
  57. type_check=AbsSpecAug,
  58. default=None,
  59. optional=True,
  60. )
  61. normalize_choices = ClassChoices(
  62. "normalize",
  63. classes=dict(
  64. global_mvn=GlobalMVN,
  65. utterance_mvn=UtteranceMVN,
  66. ),
  67. type_check=AbsNormalize,
  68. default=None,
  69. optional=True,
  70. )
  71. model_choices = ClassChoices(
  72. "model",
  73. classes=dict(
  74. e2evad=E2EVadModel,
  75. ),
  76. type_check=object,
  77. default="e2evad",
  78. )
  79. encoder_choices = ClassChoices(
  80. "encoder",
  81. classes=dict(
  82. fsmn=FSMN,
  83. ),
  84. type_check=torch.nn.Module,
  85. default="fsmn",
  86. )
  87. class VADTask(AbsTask):
  88. # If you need more than one optimizers, change this value
  89. num_optimizers: int = 1
  90. # Add variable objects configurations
  91. class_choices_list = [
  92. # --frontend and --frontend_conf
  93. frontend_choices,
  94. # --model and --model_conf
  95. model_choices,
  96. ]
  97. # If you need to modify train() or eval() procedures, change Trainer class here
  98. trainer = Trainer
  99. @classmethod
  100. def add_task_arguments(cls, parser: argparse.ArgumentParser):
  101. group = parser.add_argument_group(description="Task related")
  102. # NOTE(kamo): add_arguments(..., required=True) can't be used
  103. # to provide --print_config mode. Instead of it, do as
  104. # required = parser.get_default("required")
  105. # required += ["token_list"]
  106. group.add_argument(
  107. "--init",
  108. type=lambda x: str_or_none(x.lower()),
  109. default=None,
  110. help="The initialization method",
  111. choices=[
  112. "chainer",
  113. "xavier_uniform",
  114. "xavier_normal",
  115. "kaiming_uniform",
  116. "kaiming_normal",
  117. None,
  118. ],
  119. )
  120. group.add_argument(
  121. "--input_size",
  122. type=int_or_none,
  123. default=None,
  124. help="The number of input dimension of the feature",
  125. )
  126. group = parser.add_argument_group(description="Preprocess related")
  127. parser.add_argument(
  128. "--speech_volume_normalize",
  129. type=float_or_none,
  130. default=None,
  131. help="Scale the maximum amplitude to the given value.",
  132. )
  133. parser.add_argument(
  134. "--rir_scp",
  135. type=str_or_none,
  136. default=None,
  137. help="The file path of rir scp file.",
  138. )
  139. parser.add_argument(
  140. "--rir_apply_prob",
  141. type=float,
  142. default=1.0,
  143. help="THe probability for applying RIR convolution.",
  144. )
  145. parser.add_argument(
  146. "--cmvn_file",
  147. type=str_or_none,
  148. default=None,
  149. help="The file path of noise scp file.",
  150. )
  151. parser.add_argument(
  152. "--noise_scp",
  153. type=str_or_none,
  154. default=None,
  155. help="The file path of noise scp file.",
  156. )
  157. parser.add_argument(
  158. "--noise_apply_prob",
  159. type=float,
  160. default=1.0,
  161. help="The probability applying Noise adding.",
  162. )
  163. parser.add_argument(
  164. "--noise_db_range",
  165. type=str,
  166. default="13_15",
  167. help="The range of noise decibel level.",
  168. )
  169. for class_choices in cls.class_choices_list:
  170. # Append --<name> and --<name>_conf.
  171. # e.g. --encoder and --encoder_conf
  172. class_choices.add_arguments(group)
  173. @classmethod
  174. def build_collate_fn(
  175. cls, args: argparse.Namespace, train: bool
  176. ) -> Callable[
  177. [Collection[Tuple[str, Dict[str, np.ndarray]]]],
  178. Tuple[List[str], Dict[str, torch.Tensor]],
  179. ]:
  180. assert check_argument_types()
  181. # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
  182. return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
  183. @classmethod
  184. def build_preprocess_fn(
  185. cls, args: argparse.Namespace, train: bool
  186. ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
  187. assert check_argument_types()
  188. # if args.use_preprocessor:
  189. # retval = CommonPreprocessor(
  190. # train=train,
  191. # # NOTE(kamo): Check attribute existence for backward compatibility
  192. # rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
  193. # rir_apply_prob=args.rir_apply_prob
  194. # if hasattr(args, "rir_apply_prob")
  195. # else 1.0,
  196. # noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
  197. # noise_apply_prob=args.noise_apply_prob
  198. # if hasattr(args, "noise_apply_prob")
  199. # else 1.0,
  200. # noise_db_range=args.noise_db_range
  201. # if hasattr(args, "noise_db_range")
  202. # else "13_15",
  203. # speech_volume_normalize=args.speech_volume_normalize
  204. # if hasattr(args, "rir_scp")
  205. # else None,
  206. # )
  207. # else:
  208. # retval = None
  209. retval = None
  210. assert check_return_type(retval)
  211. return retval
  212. @classmethod
  213. def required_data_names(
  214. cls, train: bool = True, inference: bool = False
  215. ) -> Tuple[str, ...]:
  216. if not inference:
  217. retval = ("speech", "text")
  218. else:
  219. # Recognition mode
  220. retval = ("speech",)
  221. return retval
  222. @classmethod
  223. def optional_data_names(
  224. cls, train: bool = True, inference: bool = False
  225. ) -> Tuple[str, ...]:
  226. retval = ()
  227. assert check_return_type(retval)
  228. return retval
  229. @classmethod
  230. def build_model(cls, args: argparse.Namespace):
  231. assert check_argument_types()
  232. # 4. Encoder
  233. encoder_class = encoder_choices.get_class(args.encoder)
  234. encoder = encoder_class(**args.encoder_conf)
  235. # 5. Build model
  236. try:
  237. model_class = model_choices.get_class(args.model)
  238. except AttributeError:
  239. model_class = model_choices.get_class("e2evad")
  240. # 1. frontend
  241. if args.input_size is None:
  242. # Extract features in the model
  243. frontend_class = frontend_choices.get_class(args.frontend)
  244. if args.frontend == 'wav_frontend':
  245. frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
  246. else:
  247. frontend = frontend_class(**args.frontend_conf)
  248. input_size = frontend.output_size()
  249. else:
  250. # Give features from data-loader
  251. args.frontend = None
  252. args.frontend_conf = {}
  253. frontend = None
  254. input_size = args.input_size
  255. model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
  256. return model
  257. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  258. @classmethod
  259. def build_model_from_file(
  260. cls,
  261. config_file: Union[Path, str] = None,
  262. model_file: Union[Path, str] = None,
  263. device: str = "cpu",
  264. cmvn_file: Union[Path, str] = None,
  265. ):
  266. """Build model from the files.
  267. This method is used for inference or fine-tuning.
  268. Args:
  269. config_file: The yaml file saved when training.
  270. model_file: The model file saved when training.
  271. device: Device type, "cpu", "cuda", or "cuda:N".
  272. """
  273. assert check_argument_types()
  274. if config_file is None:
  275. assert model_file is not None, (
  276. "The argument 'model_file' must be provided "
  277. "if the argument 'config_file' is not specified."
  278. )
  279. config_file = Path(model_file).parent / "config.yaml"
  280. else:
  281. config_file = Path(config_file)
  282. with config_file.open("r", encoding="utf-8") as f:
  283. args = yaml.safe_load(f)
  284. # if cmvn_file is not None:
  285. args["cmvn_file"] = cmvn_file
  286. args = argparse.Namespace(**args)
  287. model = cls.build_model(args)
  288. model.to(device)
  289. model_dict = dict()
  290. model_name_pth = None
  291. if model_file is not None:
  292. logging.info("model_file is {}".format(model_file))
  293. if device == "cuda":
  294. device = f"cuda:{torch.cuda.current_device()}"
  295. model_dir = os.path.dirname(model_file)
  296. model_name = os.path.basename(model_file)
  297. model_dict = torch.load(model_file, map_location=device)
  298. model.encoder.load_state_dict(model_dict)
  299. return model, args