# vad_inference_launch.py
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import argparse
import json
import logging
import os
import sys
from typing import Optional
from typing import Union

import numpy as np
import torch

# Limit intra-op threading before any heavy torch work is triggered.
torch.set_num_threads(1)

from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
  25. def inference_vad(
  26. batch_size: int,
  27. ngpu: int,
  28. log_level: Union[int, str],
  29. # data_path_and_name_and_type,
  30. vad_infer_config: Optional[str],
  31. vad_model_file: Optional[str],
  32. vad_cmvn_file: Optional[str] = None,
  33. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  34. key_file: Optional[str] = None,
  35. allow_variable_data_keys: bool = False,
  36. output_dir: Optional[str] = None,
  37. dtype: str = "float32",
  38. seed: int = 0,
  39. num_workers: int = 1,
  40. **kwargs,
  41. ):
  42. if batch_size > 1:
  43. raise NotImplementedError("batch decoding is not implemented")
  44. logging.basicConfig(
  45. level=log_level,
  46. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  47. )
  48. if ngpu >= 1 and torch.cuda.is_available():
  49. device = "cuda"
  50. else:
  51. device = "cpu"
  52. batch_size = 1
  53. # 1. Set random-seed
  54. set_all_random_seed(seed)
  55. # 2. Build speech2vadsegment
  56. speech2vadsegment_kwargs = dict(
  57. vad_infer_config=vad_infer_config,
  58. vad_model_file=vad_model_file,
  59. vad_cmvn_file=vad_cmvn_file,
  60. device=device,
  61. dtype=dtype,
  62. )
  63. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  64. speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
  65. def _forward(
  66. data_path_and_name_and_type,
  67. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  68. output_dir_v2: Optional[str] = None,
  69. fs: dict = None,
  70. param_dict: dict = None
  71. ):
  72. # 3. Build data-iterator
  73. if data_path_and_name_and_type is None and raw_inputs is not None:
  74. if isinstance(raw_inputs, torch.Tensor):
  75. raw_inputs = raw_inputs.numpy()
  76. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  77. loader = build_streaming_iterator(
  78. task_name="vad",
  79. preprocess_args=None,
  80. data_path_and_name_and_type=data_path_and_name_and_type,
  81. dtype=dtype,
  82. batch_size=batch_size,
  83. key_file=key_file,
  84. num_workers=num_workers,
  85. )
  86. finish_count = 0
  87. file_count = 1
  88. # 7 .Start for-loop
  89. # FIXME(kamo): The output format should be discussed about
  90. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  91. if output_path is not None:
  92. writer = DatadirWriter(output_path)
  93. ibest_writer = writer[f"1best_recog"]
  94. else:
  95. writer = None
  96. ibest_writer = None
  97. vad_results = []
  98. for keys, batch in loader:
  99. assert isinstance(batch, dict), type(batch)
  100. assert all(isinstance(s, str) for s in keys), keys
  101. _bs = len(next(iter(batch.values())))
  102. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  103. # do vad segment
  104. _, results = speech2vadsegment(**batch)
  105. for i, _ in enumerate(keys):
  106. if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
  107. results[i] = json.dumps(results[i])
  108. item = {'key': keys[i], 'value': results[i]}
  109. vad_results.append(item)
  110. if writer is not None:
  111. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  112. return vad_results
  113. return _forward
  114. def inference_vad_online(
  115. batch_size: int,
  116. ngpu: int,
  117. log_level: Union[int, str],
  118. # data_path_and_name_and_type,
  119. vad_infer_config: Optional[str],
  120. vad_model_file: Optional[str],
  121. vad_cmvn_file: Optional[str] = None,
  122. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  123. key_file: Optional[str] = None,
  124. allow_variable_data_keys: bool = False,
  125. output_dir: Optional[str] = None,
  126. dtype: str = "float32",
  127. seed: int = 0,
  128. num_workers: int = 1,
  129. **kwargs,
  130. ):
  131. logging.basicConfig(
  132. level=log_level,
  133. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  134. )
  135. if ngpu >= 1 and torch.cuda.is_available():
  136. device = "cuda"
  137. else:
  138. device = "cpu"
  139. batch_size = 1
  140. # 1. Set random-seed
  141. set_all_random_seed(seed)
  142. # 2. Build speech2vadsegment
  143. speech2vadsegment_kwargs = dict(
  144. vad_infer_config=vad_infer_config,
  145. vad_model_file=vad_model_file,
  146. vad_cmvn_file=vad_cmvn_file,
  147. device=device,
  148. dtype=dtype,
  149. )
  150. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  151. speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
  152. def _forward(
  153. data_path_and_name_and_type,
  154. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  155. output_dir_v2: Optional[str] = None,
  156. fs: dict = None,
  157. param_dict: dict = None,
  158. ):
  159. # 3. Build data-iterator
  160. if data_path_and_name_and_type is None and raw_inputs is not None:
  161. if isinstance(raw_inputs, torch.Tensor):
  162. raw_inputs = raw_inputs.numpy()
  163. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  164. loader = build_streaming_iterator(
  165. task_name="vad",
  166. preprocess_args=None,
  167. data_path_and_name_and_type=data_path_and_name_and_type,
  168. dtype=dtype,
  169. batch_size=batch_size,
  170. key_file=key_file,
  171. num_workers=num_workers,
  172. )
  173. finish_count = 0
  174. file_count = 1
  175. # 7 .Start for-loop
  176. # FIXME(kamo): The output format should be discussed about
  177. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  178. if output_path is not None:
  179. writer = DatadirWriter(output_path)
  180. ibest_writer = writer[f"1best_recog"]
  181. else:
  182. writer = None
  183. ibest_writer = None
  184. vad_results = []
  185. if param_dict is None:
  186. param_dict = dict()
  187. param_dict['in_cache'] = dict()
  188. param_dict['is_final'] = True
  189. batch_in_cache = param_dict.get('in_cache', dict())
  190. is_final = param_dict.get('is_final', False)
  191. max_end_sil = param_dict.get('max_end_sil', 800)
  192. for keys, batch in loader:
  193. assert isinstance(batch, dict), type(batch)
  194. assert all(isinstance(s, str) for s in keys), keys
  195. _bs = len(next(iter(batch.values())))
  196. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  197. batch['in_cache'] = batch_in_cache
  198. batch['is_final'] = is_final
  199. batch['max_end_sil'] = max_end_sil
  200. # do vad segment
  201. _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
  202. # param_dict['in_cache'] = batch['in_cache']
  203. if results:
  204. for i, _ in enumerate(keys):
  205. if results[i]:
  206. if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
  207. results[i] = json.dumps(results[i])
  208. item = {'key': keys[i], 'value': results[i]}
  209. vad_results.append(item)
  210. if writer is not None:
  211. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  212. return vad_results
  213. return _forward
  214. def inference_launch(mode, **kwargs):
  215. if mode == "offline":
  216. return inference_vad(**kwargs)
  217. elif mode == "online":
  218. return inference_vad_online(**kwargs)
  219. else:
  220. logging.info("Unknown decoding mode: {}".format(mode))
  221. return None
  222. def get_parser():
  223. parser = config_argparse.ArgumentParser(
  224. description="VAD Decoding",
  225. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  226. )
  227. # Note(kamo): Use '_' instead of '-' as separator.
  228. # '-' is confusing if written in yaml.
  229. parser.add_argument(
  230. "--log_level",
  231. type=lambda x: x.upper(),
  232. default="INFO",
  233. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  234. help="The verbose level of logging",
  235. )
  236. parser.add_argument("--output_dir", type=str, required=True)
  237. parser.add_argument(
  238. "--ngpu",
  239. type=int,
  240. default=0,
  241. help="The number of gpus. 0 indicates CPU mode",
  242. )
  243. parser.add_argument(
  244. "--njob",
  245. type=int,
  246. default=1,
  247. help="The number of jobs for each gpu",
  248. )
  249. parser.add_argument(
  250. "--gpuid_list",
  251. type=str,
  252. default="",
  253. help="The visible gpus",
  254. )
  255. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  256. parser.add_argument(
  257. "--dtype",
  258. default="float32",
  259. choices=["float16", "float32", "float64"],
  260. help="Data type",
  261. )
  262. parser.add_argument(
  263. "--num_workers",
  264. type=int,
  265. default=1,
  266. help="The number of workers used for DataLoader",
  267. )
  268. group = parser.add_argument_group("Input data related")
  269. group.add_argument(
  270. "--data_path_and_name_and_type",
  271. type=str2triple_str,
  272. required=True,
  273. action="append",
  274. )
  275. group.add_argument("--key_file", type=str_or_none)
  276. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  277. group = parser.add_argument_group("The model configuration related")
  278. group.add_argument(
  279. "--vad_infer_config",
  280. type=str,
  281. help="VAD infer configuration",
  282. )
  283. group.add_argument(
  284. "--vad_model_file",
  285. type=str,
  286. help="VAD model parameter file",
  287. )
  288. group.add_argument(
  289. "--vad_cmvn_file",
  290. type=str,
  291. help="Global CMVN file",
  292. )
  293. group.add_argument(
  294. "--vad_train_config",
  295. type=str,
  296. help="VAD training configuration",
  297. )
  298. group = parser.add_argument_group("The inference configuration related")
  299. group.add_argument(
  300. "--batch_size",
  301. type=int,
  302. default=1,
  303. help="The batch size for inference",
  304. )
  305. return parser
  306. def main(cmd=None):
  307. print(get_commandline_args(), file=sys.stderr)
  308. parser = get_parser()
  309. parser.add_argument(
  310. "--mode",
  311. type=str,
  312. default="vad",
  313. help="The decoding mode",
  314. )
  315. args = parser.parse_args(cmd)
  316. kwargs = vars(args)
  317. kwargs.pop("config", None)
  318. # set logging messages
  319. logging.basicConfig(
  320. level=args.log_level,
  321. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  322. )
  323. logging.info("Decoding args: {}".format(kwargs))
  324. # gpu setting
  325. if args.ngpu > 0:
  326. jobid = int(args.output_dir.split(".")[-1])
  327. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  328. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  329. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  330. inference_pipeline = inference_launch(**kwargs)
  331. return inference_pipeline(kwargs["data_path_and_name_and_type"])
  332. if __name__ == "__main__":
  333. main()