  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import torch
  6. torch.set_num_threads(1)
  7. import argparse
  8. import logging
  9. import os
  10. import sys
  11. import json
  12. from typing import Optional
  13. from typing import Union
  14. import numpy as np
  15. import torch
  16. from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
  17. from funasr.fileio.datadir_writer import DatadirWriter
  18. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  19. from funasr.utils import config_argparse
  20. from funasr.utils.cli_utils import get_commandline_args
  21. from funasr.utils.types import str2bool
  22. from funasr.utils.types import str2triple_str
  23. from funasr.utils.types import str_or_none
  24. from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
  25. def inference_vad(
  26. batch_size: int,
  27. ngpu: int,
  28. log_level: Union[int, str],
  29. # data_path_and_name_and_type,
  30. vad_infer_config: Optional[str],
  31. vad_model_file: Optional[str],
  32. vad_cmvn_file: Optional[str] = None,
  33. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  34. key_file: Optional[str] = None,
  35. allow_variable_data_keys: bool = False,
  36. output_dir: Optional[str] = None,
  37. dtype: str = "float32",
  38. seed: int = 0,
  39. num_workers: int = 1,
  40. **kwargs,
  41. ):
  42. if batch_size > 1:
  43. raise NotImplementedError("batch decoding is not implemented")
  44. logging.basicConfig(
  45. level=log_level,
  46. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  47. )
  48. if ngpu >= 1 and torch.cuda.is_available():
  49. device = "cuda"
  50. else:
  51. device = "cpu"
  52. batch_size = 1
  53. # 1. Set random-seed
  54. set_all_random_seed(seed)
  55. # 2. Build speech2vadsegment
  56. speech2vadsegment_kwargs = dict(
  57. vad_infer_config=vad_infer_config,
  58. vad_model_file=vad_model_file,
  59. vad_cmvn_file=vad_cmvn_file,
  60. device=device,
  61. dtype=dtype,
  62. )
  63. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  64. speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
  65. def _forward(
  66. data_path_and_name_and_type,
  67. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  68. output_dir_v2: Optional[str] = None,
  69. fs: dict = None,
  70. param_dict: dict = None
  71. ):
  72. # 3. Build data-iterator
  73. if data_path_and_name_and_type is None and raw_inputs is not None:
  74. if isinstance(raw_inputs, torch.Tensor):
  75. raw_inputs = raw_inputs.numpy()
  76. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  77. loader = build_streaming_iterator(
  78. task_name="vad",
  79. preprocess_args=None,
  80. data_path_and_name_and_type=data_path_and_name_and_type,
  81. dtype=dtype,
  82. fs=fs,
  83. batch_size=batch_size,
  84. key_file=key_file,
  85. num_workers=num_workers,
  86. )
  87. finish_count = 0
  88. file_count = 1
  89. # 7 .Start for-loop
  90. # FIXME(kamo): The output format should be discussed about
  91. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  92. if output_path is not None:
  93. writer = DatadirWriter(output_path)
  94. ibest_writer = writer[f"1best_recog"]
  95. else:
  96. writer = None
  97. ibest_writer = None
  98. vad_results = []
  99. for keys, batch in loader:
  100. assert isinstance(batch, dict), type(batch)
  101. assert all(isinstance(s, str) for s in keys), keys
  102. _bs = len(next(iter(batch.values())))
  103. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  104. # do vad segment
  105. _, results = speech2vadsegment(**batch)
  106. for i, _ in enumerate(keys):
  107. if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
  108. results[i] = json.dumps(results[i])
  109. item = {'key': keys[i], 'value': results[i]}
  110. vad_results.append(item)
  111. if writer is not None:
  112. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  113. torch.cuda.empty_cache()
  114. return vad_results
  115. return _forward
  116. def inference_vad_online(
  117. batch_size: int,
  118. ngpu: int,
  119. log_level: Union[int, str],
  120. # data_path_and_name_and_type,
  121. vad_infer_config: Optional[str],
  122. vad_model_file: Optional[str],
  123. vad_cmvn_file: Optional[str] = None,
  124. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  125. key_file: Optional[str] = None,
  126. allow_variable_data_keys: bool = False,
  127. output_dir: Optional[str] = None,
  128. dtype: str = "float32",
  129. seed: int = 0,
  130. num_workers: int = 1,
  131. **kwargs,
  132. ):
  133. logging.basicConfig(
  134. level=log_level,
  135. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  136. )
  137. if ngpu >= 1 and torch.cuda.is_available():
  138. device = "cuda"
  139. else:
  140. device = "cpu"
  141. batch_size = 1
  142. # 1. Set random-seed
  143. set_all_random_seed(seed)
  144. # 2. Build speech2vadsegment
  145. speech2vadsegment_kwargs = dict(
  146. vad_infer_config=vad_infer_config,
  147. vad_model_file=vad_model_file,
  148. vad_cmvn_file=vad_cmvn_file,
  149. device=device,
  150. dtype=dtype,
  151. )
  152. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  153. speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
  154. def _forward(
  155. data_path_and_name_and_type,
  156. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  157. output_dir_v2: Optional[str] = None,
  158. fs: dict = None,
  159. param_dict: dict = None,
  160. ):
  161. # 3. Build data-iterator
  162. if data_path_and_name_and_type is None and raw_inputs is not None:
  163. if isinstance(raw_inputs, torch.Tensor):
  164. raw_inputs = raw_inputs.numpy()
  165. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  166. loader = build_streaming_iterator(
  167. task_name="vad",
  168. preprocess_args=None,
  169. data_path_and_name_and_type=data_path_and_name_and_type,
  170. dtype=dtype,
  171. fs=fs,
  172. batch_size=batch_size,
  173. key_file=key_file,
  174. num_workers=num_workers,
  175. )
  176. finish_count = 0
  177. file_count = 1
  178. # 7 .Start for-loop
  179. # FIXME(kamo): The output format should be discussed about
  180. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  181. if output_path is not None:
  182. writer = DatadirWriter(output_path)
  183. ibest_writer = writer[f"1best_recog"]
  184. else:
  185. writer = None
  186. ibest_writer = None
  187. vad_results = []
  188. if param_dict is None:
  189. param_dict = dict()
  190. param_dict['in_cache'] = dict()
  191. param_dict['is_final'] = True
  192. batch_in_cache = param_dict.get('in_cache', dict())
  193. is_final = param_dict.get('is_final', False)
  194. max_end_sil = param_dict.get('max_end_sil', 800)
  195. for keys, batch in loader:
  196. assert isinstance(batch, dict), type(batch)
  197. assert all(isinstance(s, str) for s in keys), keys
  198. _bs = len(next(iter(batch.values())))
  199. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  200. batch['in_cache'] = batch_in_cache
  201. batch['is_final'] = is_final
  202. batch['max_end_sil'] = max_end_sil
  203. # do vad segment
  204. _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
  205. # param_dict['in_cache'] = batch['in_cache']
  206. if results:
  207. for i, _ in enumerate(keys):
  208. if results[i]:
  209. if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
  210. results[i] = json.dumps(results[i])
  211. item = {'key': keys[i], 'value': results[i]}
  212. vad_results.append(item)
  213. if writer is not None:
  214. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  215. return vad_results
  216. return _forward
  217. def inference_launch(mode, **kwargs):
  218. if mode == "offline":
  219. return inference_vad(**kwargs)
  220. elif mode == "online":
  221. return inference_vad_online(**kwargs)
  222. else:
  223. logging.info("Unknown decoding mode: {}".format(mode))
  224. return None
  225. def get_parser():
  226. parser = config_argparse.ArgumentParser(
  227. description="VAD Decoding",
  228. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  229. )
  230. # Note(kamo): Use '_' instead of '-' as separator.
  231. # '-' is confusing if written in yaml.
  232. parser.add_argument(
  233. "--log_level",
  234. type=lambda x: x.upper(),
  235. default="INFO",
  236. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  237. help="The verbose level of logging",
  238. )
  239. parser.add_argument("--output_dir", type=str, required=True)
  240. parser.add_argument(
  241. "--ngpu",
  242. type=int,
  243. default=0,
  244. help="The number of gpus. 0 indicates CPU mode",
  245. )
  246. parser.add_argument(
  247. "--njob",
  248. type=int,
  249. default=1,
  250. help="The number of jobs for each gpu",
  251. )
  252. parser.add_argument(
  253. "--gpuid_list",
  254. type=str,
  255. default="",
  256. help="The visible gpus",
  257. )
  258. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  259. parser.add_argument(
  260. "--dtype",
  261. default="float32",
  262. choices=["float16", "float32", "float64"],
  263. help="Data type",
  264. )
  265. parser.add_argument(
  266. "--num_workers",
  267. type=int,
  268. default=1,
  269. help="The number of workers used for DataLoader",
  270. )
  271. group = parser.add_argument_group("Input data related")
  272. group.add_argument(
  273. "--data_path_and_name_and_type",
  274. type=str2triple_str,
  275. required=True,
  276. action="append",
  277. )
  278. group.add_argument("--key_file", type=str_or_none)
  279. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  280. group = parser.add_argument_group("The model configuration related")
  281. group.add_argument(
  282. "--vad_infer_config",
  283. type=str,
  284. help="VAD infer configuration",
  285. )
  286. group.add_argument(
  287. "--vad_model_file",
  288. type=str,
  289. help="VAD model parameter file",
  290. )
  291. group.add_argument(
  292. "--vad_cmvn_file",
  293. type=str,
  294. help="Global CMVN file",
  295. )
  296. group.add_argument(
  297. "--vad_train_config",
  298. type=str,
  299. help="VAD training configuration",
  300. )
  301. group = parser.add_argument_group("The inference configuration related")
  302. group.add_argument(
  303. "--batch_size",
  304. type=int,
  305. default=1,
  306. help="The batch size for inference",
  307. )
  308. return parser
  309. def main(cmd=None):
  310. print(get_commandline_args(), file=sys.stderr)
  311. parser = get_parser()
  312. parser.add_argument(
  313. "--mode",
  314. type=str,
  315. default="vad",
  316. help="The decoding mode",
  317. )
  318. args = parser.parse_args(cmd)
  319. kwargs = vars(args)
  320. kwargs.pop("config", None)
  321. # set logging messages
  322. logging.basicConfig(
  323. level=args.log_level,
  324. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  325. )
  326. logging.info("Decoding args: {}".format(kwargs))
  327. # gpu setting
  328. if args.ngpu > 0:
  329. jobid = int(args.output_dir.split(".")[-1])
  330. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  331. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  332. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  333. inference_pipeline = inference_launch(**kwargs)
  334. return inference_pipeline(kwargs["data_path_and_name_and_type"])
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()