# vad_inference.py — VAD (voice activity detection) inference entry point.
  1. import argparse
  2. import logging
  3. import sys
  4. import json
  5. from pathlib import Path
  6. from typing import Any
  7. from typing import List
  8. from typing import Optional
  9. from typing import Sequence
  10. from typing import Tuple
  11. from typing import Union
  12. from typing import Dict
  13. import numpy as np
  14. import torch
  15. from typeguard import check_argument_types
  16. from typeguard import check_return_type
  17. from funasr.fileio.datadir_writer import DatadirWriter
  18. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  19. from funasr.modules.subsampling import TooShortUttError
  20. from funasr.tasks.vad import VADTask
  21. from funasr.torch_utils.device_funcs import to_device
  22. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  23. from funasr.utils import config_argparse
  24. from funasr.utils.cli_utils import get_commandline_args
  25. from funasr.utils.types import str2bool
  26. from funasr.utils.types import str2triple_str
  27. from funasr.utils.types import str_or_none
  28. from funasr.utils import asr_utils, wav_utils, postprocess_utils
  29. from funasr.models.frontend.wav_frontend import WavFrontend
# ANSI escape codes used to colorize console output (bright magenta / reset).
header_colors = '\033[95m'
end_colors = '\033[0m'

# Default language tag for the pipeline.
global_asr_language: str = 'zh-cn'
# Default sample rates in Hz: 'audio_fs' is the incoming audio rate,
# 'model_fs' the rate the model expects — presumably resampling happens
# between the two; TODO confirm against the frontend/caller.
global_sample_rate: Union[int, Dict[Any, int]] = {
    'audio_fs': 16000,
    'model_fs': 16000
}
  37. class Speech2VadSegment:
  38. """Speech2VadSegment class
  39. Examples:
  40. >>> import soundfile
  41. >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
  42. >>> audio, rate = soundfile.read("speech.wav")
  43. >>> speech2segment(audio)
  44. [[10, 230], [245, 450], ...]
  45. """
  46. def __init__(
  47. self,
  48. vad_infer_config: Union[Path, str] = None,
  49. vad_model_file: Union[Path, str] = None,
  50. vad_cmvn_file: Union[Path, str] = None,
  51. device: str = "cpu",
  52. batch_size: int = 1,
  53. dtype: str = "float32",
  54. **kwargs,
  55. ):
  56. assert check_argument_types()
  57. # 1. Build vad model
  58. vad_model, vad_infer_args = VADTask.build_model_from_file(
  59. vad_infer_config, vad_model_file, device
  60. )
  61. frontend = None
  62. if vad_infer_args.frontend is not None:
  63. frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
  64. logging.info("vad_model: {}".format(vad_model))
  65. logging.info("vad_infer_args: {}".format(vad_infer_args))
  66. vad_model.to(dtype=getattr(torch, dtype)).eval()
  67. self.vad_model = vad_model
  68. self.vad_infer_args = vad_infer_args
  69. self.device = device
  70. self.dtype = dtype
  71. self.frontend = frontend
  72. self.batch_size = batch_size
  73. @torch.no_grad()
  74. def __call__(
  75. self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
  76. ) -> List[List[int]]:
  77. """Inference
  78. Args:
  79. speech: Input speech data
  80. Returns:
  81. text, token, token_int, hyp
  82. """
  83. assert check_argument_types()
  84. # Input as audio signal
  85. if isinstance(speech, np.ndarray):
  86. speech = torch.tensor(speech)
  87. if self.frontend is not None:
  88. feats, feats_len = self.frontend.forward(speech, speech_lengths)
  89. feats = to_device(feats, device=self.device)
  90. feats_len = feats_len.int()
  91. else:
  92. raise Exception("Need to extract feats first, please configure frontend configuration")
  93. # b. Forward Encoder streaming
  94. t_offset = 0
  95. step = min(feats_len, 6000)
  96. segments = [[]] * self.batch_size
  97. for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
  98. if t_offset + step >= feats_len - 1:
  99. step = feats_len - t_offset
  100. is_final_send = True
  101. else:
  102. is_final_send = False
  103. batch = {
  104. "feats": feats[:, t_offset:t_offset + step, :],
  105. "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
  106. "is_final_send": is_final_send
  107. }
  108. # a. To device
  109. batch = to_device(batch, device=self.device)
  110. segments_part = self.vad_model(**batch)
  111. if segments_part:
  112. for batch_num in range(0, self.batch_size):
  113. segments[batch_num] += segments_part[batch_num]
  114. return segments
  115. def inference(
  116. batch_size: int,
  117. ngpu: int,
  118. log_level: Union[int, str],
  119. data_path_and_name_and_type,
  120. vad_infer_config: Optional[str],
  121. vad_model_file: Optional[str],
  122. vad_cmvn_file: Optional[str] = None,
  123. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  124. key_file: Optional[str] = None,
  125. allow_variable_data_keys: bool = False,
  126. output_dir: Optional[str] = None,
  127. dtype: str = "float32",
  128. seed: int = 0,
  129. num_workers: int = 1,
  130. **kwargs,
  131. ):
  132. inference_pipeline = inference_modelscope(
  133. batch_size=batch_size,
  134. ngpu=ngpu,
  135. log_level=log_level,
  136. vad_infer_config=vad_infer_config,
  137. vad_model_file=vad_model_file,
  138. vad_cmvn_file=vad_cmvn_file,
  139. key_file=key_file,
  140. allow_variable_data_keys=allow_variable_data_keys,
  141. output_dir=output_dir,
  142. dtype=dtype,
  143. seed=seed,
  144. num_workers=num_workers,
  145. **kwargs,
  146. )
  147. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
  148. def inference_modelscope(
  149. batch_size: int,
  150. ngpu: int,
  151. log_level: Union[int, str],
  152. # data_path_and_name_and_type,
  153. vad_infer_config: Optional[str],
  154. vad_model_file: Optional[str],
  155. vad_cmvn_file: Optional[str] = None,
  156. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  157. key_file: Optional[str] = None,
  158. allow_variable_data_keys: bool = False,
  159. output_dir: Optional[str] = None,
  160. dtype: str = "float32",
  161. seed: int = 0,
  162. num_workers: int = 1,
  163. **kwargs,
  164. ):
  165. assert check_argument_types()
  166. if batch_size > 1:
  167. raise NotImplementedError("batch decoding is not implemented")
  168. if ngpu > 1:
  169. raise NotImplementedError("only single GPU decoding is supported")
  170. logging.basicConfig(
  171. level=log_level,
  172. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  173. )
  174. if ngpu >= 1 and torch.cuda.is_available():
  175. device = "cuda"
  176. else:
  177. device = "cpu"
  178. # 1. Set random-seed
  179. set_all_random_seed(seed)
  180. # 2. Build speech2vadsegment
  181. speech2vadsegment_kwargs = dict(
  182. vad_infer_config=vad_infer_config,
  183. vad_model_file=vad_model_file,
  184. vad_cmvn_file=vad_cmvn_file,
  185. device=device,
  186. dtype=dtype,
  187. )
  188. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  189. speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
  190. def _forward(
  191. data_path_and_name_and_type,
  192. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  193. output_dir_v2: Optional[str] = None,
  194. fs: dict = None,
  195. param_dict: dict = None,
  196. ):
  197. # 3. Build data-iterator
  198. loader = VADTask.build_streaming_iterator(
  199. data_path_and_name_and_type,
  200. dtype=dtype,
  201. batch_size=batch_size,
  202. key_file=key_file,
  203. num_workers=num_workers,
  204. preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
  205. collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
  206. allow_variable_data_keys=allow_variable_data_keys,
  207. inference=True,
  208. )
  209. finish_count = 0
  210. file_count = 1
  211. # 7 .Start for-loop
  212. # FIXME(kamo): The output format should be discussed about
  213. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  214. if output_path is not None:
  215. writer = DatadirWriter(output_path)
  216. ibest_writer = writer[f"1best_recog"]
  217. else:
  218. writer = None
  219. ibest_writer = None
  220. vad_results = []
  221. for keys, batch in loader:
  222. assert isinstance(batch, dict), type(batch)
  223. assert all(isinstance(s, str) for s in keys), keys
  224. _bs = len(next(iter(batch.values())))
  225. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  226. # do vad segment
  227. results = speech2vadsegment(**batch)
  228. for i, _ in enumerate(keys):
  229. results[i] = json.dumps(results[i])
  230. item = {'key': keys[i], 'value': results[i]}
  231. vad_results.append(item)
  232. if writer is not None:
  233. results[i] = json.loads(results[i])
  234. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  235. return vad_results
  236. return _forward
  237. def get_parser():
  238. parser = config_argparse.ArgumentParser(
  239. description="VAD Decoding",
  240. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  241. )
  242. # Note(kamo): Use '_' instead of '-' as separator.
  243. # '-' is confusing if written in yaml.
  244. parser.add_argument(
  245. "--log_level",
  246. type=lambda x: x.upper(),
  247. default="INFO",
  248. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  249. help="The verbose level of logging",
  250. )
  251. parser.add_argument("--output_dir", type=str, required=False)
  252. parser.add_argument(
  253. "--ngpu",
  254. type=int,
  255. default=0,
  256. help="The number of gpus. 0 indicates CPU mode",
  257. )
  258. parser.add_argument(
  259. "--gpuid_list",
  260. type=str,
  261. default="",
  262. help="The visible gpus",
  263. )
  264. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  265. parser.add_argument(
  266. "--dtype",
  267. default="float32",
  268. choices=["float16", "float32", "float64"],
  269. help="Data type",
  270. )
  271. parser.add_argument(
  272. "--num_workers",
  273. type=int,
  274. default=1,
  275. help="The number of workers used for DataLoader",
  276. )
  277. group = parser.add_argument_group("Input data related")
  278. group.add_argument(
  279. "--data_path_and_name_and_type",
  280. type=str2triple_str,
  281. required=False,
  282. action="append",
  283. )
  284. group.add_argument("--raw_inputs", type=list, default=None)
  285. # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
  286. group.add_argument("--key_file", type=str_or_none)
  287. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  288. group = parser.add_argument_group("The model configuration related")
  289. group.add_argument(
  290. "--vad_infer_config",
  291. type=str,
  292. help="VAD infer configuration",
  293. )
  294. group.add_argument(
  295. "--vad_model_file",
  296. type=str,
  297. help="VAD model parameter file",
  298. )
  299. group.add_argument(
  300. "--vad_cmvn_file",
  301. type=str,
  302. help="Global cmvn file",
  303. )
  304. group = parser.add_argument_group("infer related")
  305. group.add_argument(
  306. "--batch_size",
  307. type=int,
  308. default=1,
  309. help="The batch size for inference",
  310. )
  311. return parser
  312. def main(cmd=None):
  313. print(get_commandline_args(), file=sys.stderr)
  314. parser = get_parser()
  315. args = parser.parse_args(cmd)
  316. kwargs = vars(args)
  317. kwargs.pop("config", None)
  318. inference(**kwargs)
  319. if __name__ == "__main__":
  320. main()