# vad_inference.py
  1. import argparse
  2. import logging
  3. import sys
  4. from pathlib import Path
  5. from typing import Any
  6. from typing import List
  7. from typing import Optional
  8. from typing import Sequence
  9. from typing import Tuple
  10. from typing import Union
  11. from typing import Dict
  12. import numpy as np
  13. import torch
  14. from typeguard import check_argument_types
  15. from typeguard import check_return_type
  16. from funasr.fileio.datadir_writer import DatadirWriter
  17. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  18. from funasr.modules.subsampling import TooShortUttError
  19. from funasr.tasks.vad import VADTask
  20. from funasr.torch_utils.device_funcs import to_device
  21. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  22. from funasr.utils import config_argparse
  23. from funasr.utils.cli_utils import get_commandline_args
  24. from funasr.utils.types import str2bool
  25. from funasr.utils.types import str2triple_str
  26. from funasr.utils.types import str_or_none
  27. from funasr.utils import asr_utils, wav_utils, postprocess_utils
  28. from funasr.models.frontend.wav_frontend import WavFrontend
# ANSI escape sequences used to colorize terminal output.
header_colors = '\033[95m'
end_colors = '\033[0m'

# Default language tag for the pipeline; 'zh-cn' = Mandarin Chinese.
global_asr_language: str = 'zh-cn'
# Sample rates in Hz: 'audio_fs' is the incoming audio rate,
# 'model_fs' is the rate the model expects.
global_sample_rate: Union[int, Dict[Any, int]] = {
    'audio_fs': 16000,
    'model_fs': 16000
}
  36. class Speech2VadSegment:
  37. """Speech2VadSegment class
  38. Examples:
  39. >>> import soundfile
  40. >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
  41. >>> audio, rate = soundfile.read("speech.wav")
  42. >>> speech2segment(audio)
  43. [[10, 230], [245, 450], ...]
  44. """
  45. def __init__(
  46. self,
  47. vad_infer_config: Union[Path, str] = None,
  48. vad_model_file: Union[Path, str] = None,
  49. vad_cmvn_file: Union[Path, str] = None,
  50. device: str = "cpu",
  51. batch_size: int = 1,
  52. dtype: str = "float32",
  53. **kwargs,
  54. ):
  55. assert check_argument_types()
  56. # 1. Build vad model
  57. vad_model, vad_infer_args = VADTask.build_model_from_file(
  58. vad_infer_config, vad_model_file, device
  59. )
  60. frontend = None
  61. if vad_infer_args.frontend is not None:
  62. frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
  63. logging.info("vad_model: {}".format(vad_model))
  64. logging.info("vad_infer_args: {}".format(vad_infer_args))
  65. vad_model.to(dtype=getattr(torch, dtype)).eval()
  66. self.vad_model = vad_model
  67. self.vad_infer_args = vad_infer_args
  68. self.device = device
  69. self.dtype = dtype
  70. self.frontend = frontend
  71. @torch.no_grad()
  72. def __call__(
  73. self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
  74. ) -> List[List[int]]:
  75. """Inference
  76. Args:
  77. speech: Input speech data
  78. Returns:
  79. text, token, token_int, hyp
  80. """
  81. assert check_argument_types()
  82. # Input as audio signal
  83. if isinstance(speech, np.ndarray):
  84. speech = torch.tensor(speech)
  85. if self.frontend is not None:
  86. feats, feats_len = self.frontend.forward(speech, speech_lengths)
  87. feats = to_device(feats, device=self.device)
  88. feats_len = feats_len.int()
  89. else:
  90. raise Exception("Need to extract feats first, please configure frontend configuration")
  91. batch = {"feats": feats, "feats_lengths": feats_len, "waveform": speech}
  92. # a. To device
  93. batch = to_device(batch, device=self.device)
  94. # b. Forward Encoder
  95. segments = self.vad_model(**batch)
  96. return segments
  97. def inference(
  98. batch_size: int,
  99. ngpu: int,
  100. log_level: Union[int, str],
  101. data_path_and_name_and_type,
  102. vad_infer_config: Optional[str],
  103. vad_model_file: Optional[str],
  104. vad_cmvn_file: Optional[str] = None,
  105. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  106. key_file: Optional[str] = None,
  107. allow_variable_data_keys: bool = False,
  108. output_dir: Optional[str] = None,
  109. dtype: str = "float32",
  110. seed: int = 0,
  111. num_workers: int = 1,
  112. **kwargs,
  113. ):
  114. inference_pipeline = inference_modelscope(
  115. batch_size=batch_size,
  116. ngpu=ngpu,
  117. log_level=log_level,
  118. vad_infer_config=vad_infer_config,
  119. vad_model_file=vad_model_file,
  120. vad_cmvn_file=vad_cmvn_file,
  121. key_file=key_file,
  122. allow_variable_data_keys=allow_variable_data_keys,
  123. output_dir=output_dir,
  124. dtype=dtype,
  125. seed=seed,
  126. num_workers=num_workers,
  127. **kwargs,
  128. )
  129. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
  130. def inference_modelscope(
  131. batch_size: int,
  132. ngpu: int,
  133. log_level: Union[int, str],
  134. #data_path_and_name_and_type,
  135. vad_infer_config: Optional[str],
  136. vad_model_file: Optional[str],
  137. vad_cmvn_file: Optional[str] = None,
  138. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  139. key_file: Optional[str] = None,
  140. allow_variable_data_keys: bool = False,
  141. output_dir: Optional[str] = None,
  142. dtype: str = "float32",
  143. seed: int = 0,
  144. num_workers: int = 1,
  145. param_dict: dict = None,
  146. **kwargs,
  147. ):
  148. assert check_argument_types()
  149. if batch_size > 1:
  150. raise NotImplementedError("batch decoding is not implemented")
  151. if ngpu > 1:
  152. raise NotImplementedError("only single GPU decoding is supported")
  153. logging.basicConfig(
  154. level=log_level,
  155. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  156. )
  157. if ngpu >= 1 and torch.cuda.is_available():
  158. device = "cuda"
  159. else:
  160. device = "cpu"
  161. # 1. Set random-seed
  162. set_all_random_seed(seed)
  163. # 2. Build speech2vadsegment
  164. speech2vadsegment_kwargs = dict(
  165. vad_infer_config=vad_infer_config,
  166. vad_model_file=vad_model_file,
  167. vad_cmvn_file=vad_cmvn_file,
  168. device=device,
  169. dtype=dtype,
  170. )
  171. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  172. speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
  173. def _forward(
  174. data_path_and_name_and_type,
  175. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  176. output_dir_v2: Optional[str] = None,
  177. fs: dict = None,
  178. param_dict: dict = None,
  179. ):
  180. # 3. Build data-iterator
  181. loader = VADTask.build_streaming_iterator(
  182. data_path_and_name_and_type,
  183. dtype=dtype,
  184. batch_size=batch_size,
  185. key_file=key_file,
  186. num_workers=num_workers,
  187. preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
  188. collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
  189. allow_variable_data_keys=allow_variable_data_keys,
  190. inference=True,
  191. )
  192. finish_count = 0
  193. file_count = 1
  194. # 7 .Start for-loop
  195. # FIXME(kamo): The output format should be discussed about
  196. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  197. if output_path is not None:
  198. writer = DatadirWriter(output_path)
  199. ibest_writer = writer[f"1best_recog"]
  200. else:
  201. writer = None
  202. ibest_writer = None
  203. vad_results = []
  204. for keys, batch in loader:
  205. assert isinstance(batch, dict), type(batch)
  206. assert all(isinstance(s, str) for s in keys), keys
  207. _bs = len(next(iter(batch.values())))
  208. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  209. # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  210. # do vad segment
  211. results = speech2vadsegment(**batch)
  212. for i, _ in enumerate(keys):
  213. item = {'key': keys[i], 'value': results[i]}
  214. vad_results.append(item)
  215. if writer is not None:
  216. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  217. return vad_results
  218. return _forward
  219. def get_parser():
  220. parser = config_argparse.ArgumentParser(
  221. description="VAD Decoding",
  222. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  223. )
  224. # Note(kamo): Use '_' instead of '-' as separator.
  225. # '-' is confusing if written in yaml.
  226. parser.add_argument(
  227. "--log_level",
  228. type=lambda x: x.upper(),
  229. default="INFO",
  230. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  231. help="The verbose level of logging",
  232. )
  233. parser.add_argument("--output_dir", type=str, required=False)
  234. parser.add_argument(
  235. "--ngpu",
  236. type=int,
  237. default=0,
  238. help="The number of gpus. 0 indicates CPU mode",
  239. )
  240. parser.add_argument(
  241. "--gpuid_list",
  242. type=str,
  243. default="",
  244. help="The visible gpus",
  245. )
  246. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  247. parser.add_argument(
  248. "--dtype",
  249. default="float32",
  250. choices=["float16", "float32", "float64"],
  251. help="Data type",
  252. )
  253. parser.add_argument(
  254. "--num_workers",
  255. type=int,
  256. default=1,
  257. help="The number of workers used for DataLoader",
  258. )
  259. group = parser.add_argument_group("Input data related")
  260. group.add_argument(
  261. "--data_path_and_name_and_type",
  262. type=str2triple_str,
  263. required=False,
  264. action="append",
  265. )
  266. group.add_argument("--raw_inputs", type=list, default=None)
  267. # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
  268. group.add_argument("--key_file", type=str_or_none)
  269. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  270. group = parser.add_argument_group("The model configuration related")
  271. group.add_argument(
  272. "--vad_infer_config",
  273. type=str,
  274. help="VAD infer configuration",
  275. )
  276. group.add_argument(
  277. "--vad_model_file",
  278. type=str,
  279. help="VAD model parameter file",
  280. )
  281. group.add_argument(
  282. "--vad_cmvn_file",
  283. type=str,
  284. help="Global cmvn file",
  285. )
  286. group = parser.add_argument_group("infer related")
  287. group.add_argument(
  288. "--batch_size",
  289. type=int,
  290. default=1,
  291. help="The batch size for inference",
  292. )
  293. return parser
  294. def main(cmd=None):
  295. print(get_commandline_args(), file=sys.stderr)
  296. parser = get_parser()
  297. args = parser.parse_args(cmd)
  298. kwargs = vars(args)
  299. kwargs.pop("config", None)
  300. inference(**kwargs)
  301. if __name__ == "__main__":
  302. main()