vad.py

import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type

from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.e2e_vad import E2EVadModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
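

# Each ClassChoices registry below maps a CLI name to an implementation class
# and generates the paired --<name> / --<name>_conf options consumed by VADTask.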
frontend_choices = ClassChoices(
    name="frontend",
    classes=dict(
        default=DefaultFrontend,
        sliding_window=SlidingWindow,
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        wav_frontend_online=WavFrontendOnline,
    ),
    type_check=AbsFrontend,
    default="default",
)
specaug_choices = ClassChoices(
    name="specaug",
    classes=dict(
        specaug=SpecAug,
        specaug_lfr=SpecAugLFR,
    ),
    type_check=AbsSpecAug,
    default=None,
    optional=True,
)
normalize_choices = ClassChoices(
    "normalize",
    classes=dict(
        global_mvn=GlobalMVN,
        utterance_mvn=UtteranceMVN,
    ),
    type_check=AbsNormalize,
    default=None,
    optional=True,
)
model_choices = ClassChoices(
    "model",
    classes=dict(
        e2evad=E2EVadModel,
    ),
    type_check=object,
    default="e2evad",
)
encoder_choices = ClassChoices(
    "encoder",
    classes=dict(
        fsmn=FSMN,
    ),
    type_check=torch.nn.Module,
    default="fsmn",
)
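

# VADTask wires the selected frontend and FSMN encoder into an end-to-end VAD
# model (E2EVadModel) and exposes the standard AbsTask training/inference hooks.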
class VADTask(AbsTask):
    # If you need more than one optimizer, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [
        # --frontend and --frontend_conf
        frontend_choices,
        # --encoder and --encoder_conf
        encoder_choices,
        # --model and --model_conf
        model_choices,
    ]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Instead of it, do as
        # required = parser.get_default("required")
        # required += ["token_list"]

        group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        group.add_argument(
            "--input_size",
            type=int_or_none,
            default=None,
            help="The number of input dimensions of the feature",
        )

        group = parser.add_argument_group(description="Preprocess related")
        group.add_argument(
            "--speech_volume_normalize",
            type=float_or_none,
            default=None,
            help="Scale the maximum amplitude to the given value.",
        )
        group.add_argument(
            "--rir_scp",
            type=str_or_none,
            default=None,
            help="The file path of the RIR scp file.",
        )
        group.add_argument(
            "--rir_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying RIR convolution.",
        )
        group.add_argument(
            "--cmvn_file",
            type=str_or_none,
            default=None,
            help="The file path of the CMVN file.",
        )
        group.add_argument(
            "--noise_scp",
            type=str_or_none,
            default=None,
            help="The file path of the noise scp file.",
        )
        group.add_argument(
            "--noise_apply_prob",
            type=float,
            default=1.0,
            help="The probability of applying noise addition.",
        )
        group.add_argument(
            "--noise_db_range",
            type=str,
            default="13_15",
            help="The range of noise decibel level.",
        )

        for class_choices in cls.class_choices_list:
            # Append --<name> and --<name>_conf.
            # e.g. --encoder and --encoder_conf
            class_choices.add_arguments(group)
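
    # Illustrative command-line usage of the options registered above (the
    # option values are examples only, not defaults shipped with FunASR):
    #   --frontend wav_frontend --frontend_conf '{"fs": 16000}'
    #   --model e2evad --model_conf '{}'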

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        # VAD training currently uses no preprocessor; the CommonPreprocessor
        # call below is kept for reference.
        # if args.use_preprocessor:
        #     retval = CommonPreprocessor(
        #         train=train,
        #         # NOTE(kamo): Check attribute existence for backward compatibility
        #         rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
        #         rir_apply_prob=args.rir_apply_prob
        #         if hasattr(args, "rir_apply_prob")
        #         else 1.0,
        #         noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
        #         noise_apply_prob=args.noise_apply_prob
        #         if hasattr(args, "noise_apply_prob")
        #         else 1.0,
        #         noise_db_range=args.noise_db_range
        #         if hasattr(args, "noise_db_range")
        #         else "13_15",
        #         speech_volume_normalize=args.speech_volume_normalize
        #         if hasattr(args, "speech_volume_normalize")
        #         else None,
        #     )
        # else:
        #     retval = None
        retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        if not inference:
            retval = ("speech", "text")
        else:
            # Recognition mode
            retval = ("speech",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()

        # 1. Frontend
        if args.input_size is None:
            # Extract features inside the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == "wav_frontend":
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Features are provided directly by the data loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size

        # 2. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(**args.encoder_conf)

        # 3. Build model
        try:
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("e2evad")
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
        return model

    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Optional[Union[Path, str]] = None,
        model_file: Optional[Union[Path, str]] = None,
        device: str = "cpu",
        cmvn_file: Optional[Union[Path, str]] = None,
    ):
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training.
            model_file: The model file saved when training.
            device: Device type, "cpu", "cuda", or "cuda:N".
            cmvn_file: Optional CMVN statistics file; overrides the one in the config.
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        # Prefer an explicitly passed CMVN file over the one in the config
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        else:
            args.setdefault("cmvn_file", None)
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        model.to(device)

        if model_file is not None:
            logging.info("model_file is {}".format(model_file))
            if device == "cuda":
                device = f"cuda:{torch.cuda.current_device()}"
            model_dict = torch.load(model_file, map_location=device)
            # The checkpoint holds the encoder's state dict; load it into the encoder
            model.encoder.load_state_dict(model_dict)
        return model, args
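

# A minimal, illustrative usage sketch (not part of the original module). It
# assumes a trained VAD experiment directory containing a config.yaml and a
# checkpoint; the paths below are hypothetical placeholders.
if __name__ == "__main__":
    vad_model, vad_args = VADTask.build_model_from_file(
        config_file="exp/vad/config.yaml",  # hypothetical path
        model_file="exp/vad/vad.pb",        # hypothetical path
        device="cpu",
    )
    vad_model.eval()
    logging.info("Loaded VAD model: %s", type(vad_model).__name__)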