vad_inference.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. import argparse
  2. import logging
  3. import sys
  4. import json
  5. from pathlib import Path
  6. from typing import Any
  7. from typing import List
  8. from typing import Optional
  9. from typing import Sequence
  10. from typing import Tuple
  11. from typing import Union
  12. from typing import Dict
  13. import math
  14. import numpy as np
  15. import torch
  16. from typeguard import check_argument_types
  17. from typeguard import check_return_type
  18. from funasr.fileio.datadir_writer import DatadirWriter
  19. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  20. from funasr.modules.subsampling import TooShortUttError
  21. from funasr.tasks.vad import VADTask
  22. from funasr.torch_utils.device_funcs import to_device
  23. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  24. from funasr.utils import config_argparse
  25. from funasr.utils.cli_utils import get_commandline_args
  26. from funasr.utils.types import str2bool
  27. from funasr.utils.types import str2triple_str
  28. from funasr.utils.types import str_or_none
  29. from funasr.utils import asr_utils, wav_utils, postprocess_utils
  30. from funasr.models.frontend.wav_frontend import WavFrontend
# ANSI escape codes used to colorize console output.
header_colors = '\033[95m'
end_colors = '\033[0m'

# Default ASR language tag for this pipeline.
global_asr_language: str = 'zh-cn'
# Sampling rates: 'audio_fs' is the incoming audio rate, 'model_fs' is the
# rate the model expects (both 16 kHz here).
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
  38. class Speech2VadSegment:
  39. """Speech2VadSegment class
  40. Examples:
  41. >>> import soundfile
  42. >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
  43. >>> audio, rate = soundfile.read("speech.wav")
  44. >>> speech2segment(audio)
  45. [[10, 230], [245, 450], ...]
  46. """
  47. def __init__(
  48. self,
  49. vad_infer_config: Union[Path, str] = None,
  50. vad_model_file: Union[Path, str] = None,
  51. vad_cmvn_file: Union[Path, str] = None,
  52. device: str = "cpu",
  53. batch_size: int = 1,
  54. dtype: str = "float32",
  55. **kwargs,
  56. ):
  57. assert check_argument_types()
  58. # 1. Build vad model
  59. vad_model, vad_infer_args = VADTask.build_model_from_file(
  60. vad_infer_config, vad_model_file, device
  61. )
  62. frontend = None
  63. if vad_infer_args.frontend is not None:
  64. frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
  65. logging.info("vad_model: {}".format(vad_model))
  66. logging.info("vad_infer_args: {}".format(vad_infer_args))
  67. vad_model.to(dtype=getattr(torch, dtype)).eval()
  68. self.vad_model = vad_model
  69. self.vad_infer_args = vad_infer_args
  70. self.device = device
  71. self.dtype = dtype
  72. self.frontend = frontend
  73. self.batch_size = batch_size
  74. @torch.no_grad()
  75. def __call__(
  76. self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
  77. in_cache: Dict[str, torch.Tensor] = dict()
  78. ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
  79. """Inference
  80. Args:
  81. speech: Input speech data
  82. Returns:
  83. text, token, token_int, hyp
  84. """
  85. assert check_argument_types()
  86. # Input as audio signal
  87. if isinstance(speech, np.ndarray):
  88. speech = torch.tensor(speech)
  89. if self.frontend is not None:
  90. self.frontend.filter_length_max = math.inf
  91. fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
  92. feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
  93. fbanks = to_device(fbanks, device=self.device)
  94. feats = to_device(feats, device=self.device)
  95. feats_len = feats_len.int()
  96. else:
  97. raise Exception("Need to extract feats first, please configure frontend configuration")
  98. # b. Forward Encoder streaming
  99. t_offset = 0
  100. step = min(feats_len.max(), 6000)
  101. segments = [[]] * self.batch_size
  102. for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
  103. if t_offset + step >= feats_len - 1:
  104. step = feats_len - t_offset
  105. is_final = True
  106. else:
  107. is_final = False
  108. batch = {
  109. "feats": feats[:, t_offset:t_offset + step, :],
  110. "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
  111. "is_final": is_final,
  112. "in_cache": in_cache
  113. }
  114. # a. To device
  115. batch = to_device(batch, device=self.device)
  116. segments_part, in_cache = self.vad_model(**batch)
  117. if segments_part:
  118. for batch_num in range(0, self.batch_size):
  119. segments[batch_num] += segments_part[batch_num]
  120. return fbanks, segments
  121. def inference(
  122. batch_size: int,
  123. ngpu: int,
  124. log_level: Union[int, str],
  125. data_path_and_name_and_type,
  126. vad_infer_config: Optional[str],
  127. vad_model_file: Optional[str],
  128. vad_cmvn_file: Optional[str] = None,
  129. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  130. key_file: Optional[str] = None,
  131. allow_variable_data_keys: bool = False,
  132. output_dir: Optional[str] = None,
  133. dtype: str = "float32",
  134. seed: int = 0,
  135. num_workers: int = 1,
  136. **kwargs,
  137. ):
  138. inference_pipeline = inference_modelscope(
  139. batch_size=batch_size,
  140. ngpu=ngpu,
  141. log_level=log_level,
  142. vad_infer_config=vad_infer_config,
  143. vad_model_file=vad_model_file,
  144. vad_cmvn_file=vad_cmvn_file,
  145. key_file=key_file,
  146. allow_variable_data_keys=allow_variable_data_keys,
  147. output_dir=output_dir,
  148. dtype=dtype,
  149. seed=seed,
  150. num_workers=num_workers,
  151. **kwargs,
  152. )
  153. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
  154. def inference_modelscope(
  155. batch_size: int,
  156. ngpu: int,
  157. log_level: Union[int, str],
  158. # data_path_and_name_and_type,
  159. vad_infer_config: Optional[str],
  160. vad_model_file: Optional[str],
  161. vad_cmvn_file: Optional[str] = None,
  162. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  163. key_file: Optional[str] = None,
  164. allow_variable_data_keys: bool = False,
  165. output_dir: Optional[str] = None,
  166. dtype: str = "float32",
  167. seed: int = 0,
  168. num_workers: int = 1,
  169. **kwargs,
  170. ):
  171. assert check_argument_types()
  172. if batch_size > 1:
  173. raise NotImplementedError("batch decoding is not implemented")
  174. if ngpu > 1:
  175. raise NotImplementedError("only single GPU decoding is supported")
  176. logging.basicConfig(
  177. level=log_level,
  178. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  179. )
  180. if ngpu >= 1 and torch.cuda.is_available():
  181. device = "cuda"
  182. else:
  183. device = "cpu"
  184. # 1. Set random-seed
  185. set_all_random_seed(seed)
  186. # 2. Build speech2vadsegment
  187. speech2vadsegment_kwargs = dict(
  188. vad_infer_config=vad_infer_config,
  189. vad_model_file=vad_model_file,
  190. vad_cmvn_file=vad_cmvn_file,
  191. device=device,
  192. dtype=dtype,
  193. )
  194. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  195. speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
  196. def _forward(
  197. data_path_and_name_and_type,
  198. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  199. output_dir_v2: Optional[str] = None,
  200. fs: dict = None,
  201. param_dict: dict = None
  202. ):
  203. # 3. Build data-iterator
  204. if data_path_and_name_and_type is None and raw_inputs is not None:
  205. if isinstance(raw_inputs, torch.Tensor):
  206. raw_inputs = raw_inputs.numpy()
  207. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  208. loader = VADTask.build_streaming_iterator(
  209. data_path_and_name_and_type,
  210. dtype=dtype,
  211. batch_size=batch_size,
  212. key_file=key_file,
  213. num_workers=num_workers,
  214. preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
  215. collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
  216. allow_variable_data_keys=allow_variable_data_keys,
  217. inference=True,
  218. )
  219. finish_count = 0
  220. file_count = 1
  221. # 7 .Start for-loop
  222. # FIXME(kamo): The output format should be discussed about
  223. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  224. if output_path is not None:
  225. writer = DatadirWriter(output_path)
  226. ibest_writer = writer[f"1best_recog"]
  227. else:
  228. writer = None
  229. ibest_writer = None
  230. vad_results = []
  231. for keys, batch in loader:
  232. assert isinstance(batch, dict), type(batch)
  233. assert all(isinstance(s, str) for s in keys), keys
  234. _bs = len(next(iter(batch.values())))
  235. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  236. # do vad segment
  237. _, results = speech2vadsegment(**batch)
  238. for i, _ in enumerate(keys):
  239. results[i] = json.dumps(results[i])
  240. item = {'key': keys[i], 'value': results[i]}
  241. vad_results.append(item)
  242. if writer is not None:
  243. results[i] = json.loads(results[i])
  244. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  245. return vad_results
  246. return _forward
  247. def get_parser():
  248. parser = config_argparse.ArgumentParser(
  249. description="VAD Decoding",
  250. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  251. )
  252. # Note(kamo): Use '_' instead of '-' as separator.
  253. # '-' is confusing if written in yaml.
  254. parser.add_argument(
  255. "--log_level",
  256. type=lambda x: x.upper(),
  257. default="INFO",
  258. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  259. help="The verbose level of logging",
  260. )
  261. parser.add_argument("--output_dir", type=str, required=False)
  262. parser.add_argument(
  263. "--ngpu",
  264. type=int,
  265. default=0,
  266. help="The number of gpus. 0 indicates CPU mode",
  267. )
  268. parser.add_argument(
  269. "--gpuid_list",
  270. type=str,
  271. default="",
  272. help="The visible gpus",
  273. )
  274. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  275. parser.add_argument(
  276. "--dtype",
  277. default="float32",
  278. choices=["float16", "float32", "float64"],
  279. help="Data type",
  280. )
  281. parser.add_argument(
  282. "--num_workers",
  283. type=int,
  284. default=1,
  285. help="The number of workers used for DataLoader",
  286. )
  287. group = parser.add_argument_group("Input data related")
  288. group.add_argument(
  289. "--data_path_and_name_and_type",
  290. type=str2triple_str,
  291. required=False,
  292. action="append",
  293. )
  294. group.add_argument("--raw_inputs", type=list, default=None)
  295. # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
  296. group.add_argument("--key_file", type=str_or_none)
  297. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  298. group = parser.add_argument_group("The model configuration related")
  299. group.add_argument(
  300. "--vad_infer_config",
  301. type=str,
  302. help="VAD infer configuration",
  303. )
  304. group.add_argument(
  305. "--vad_model_file",
  306. type=str,
  307. help="VAD model parameter file",
  308. )
  309. group.add_argument(
  310. "--vad_cmvn_file",
  311. type=str,
  312. help="Global cmvn file",
  313. )
  314. group = parser.add_argument_group("infer related")
  315. group.add_argument(
  316. "--batch_size",
  317. type=int,
  318. default=1,
  319. help="The batch size for inference",
  320. )
  321. return parser
  322. def main(cmd=None):
  323. print(get_commandline_args(), file=sys.stderr)
  324. parser = get_parser()
  325. args = parser.parse_args(cmd)
  326. kwargs = vars(args)
  327. kwargs.pop("config", None)
  328. inference(**kwargs)
  329. if __name__ == "__main__":
  330. main()