vad_inference_launch.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import torch
  6. torch.set_num_threads(1)
  7. import argparse
  8. import logging
  9. import os
  10. import sys
  11. import json
  12. from typing import Optional
  13. from typing import Union
  14. import numpy as np
  15. import torch
  16. from typeguard import check_argument_types
  17. from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
  18. from funasr.fileio.datadir_writer import DatadirWriter
  19. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  20. from funasr.utils import config_argparse
  21. from funasr.utils.cli_utils import get_commandline_args
  22. from funasr.utils.types import str2bool
  23. from funasr.utils.types import str2triple_str
  24. from funasr.utils.types import str_or_none
  25. from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
  26. def inference_vad(
  27. batch_size: int,
  28. ngpu: int,
  29. log_level: Union[int, str],
  30. # data_path_and_name_and_type,
  31. vad_infer_config: Optional[str],
  32. vad_model_file: Optional[str],
  33. vad_cmvn_file: Optional[str] = None,
  34. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  35. key_file: Optional[str] = None,
  36. allow_variable_data_keys: bool = False,
  37. output_dir: Optional[str] = None,
  38. dtype: str = "float32",
  39. seed: int = 0,
  40. num_workers: int = 1,
  41. **kwargs,
  42. ):
  43. assert check_argument_types()
  44. if batch_size > 1:
  45. raise NotImplementedError("batch decoding is not implemented")
  46. logging.basicConfig(
  47. level=log_level,
  48. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  49. )
  50. if ngpu >= 1 and torch.cuda.is_available():
  51. device = "cuda"
  52. else:
  53. device = "cpu"
  54. batch_size = 1
  55. # 1. Set random-seed
  56. set_all_random_seed(seed)
  57. # 2. Build speech2vadsegment
  58. speech2vadsegment_kwargs = dict(
  59. vad_infer_config=vad_infer_config,
  60. vad_model_file=vad_model_file,
  61. vad_cmvn_file=vad_cmvn_file,
  62. device=device,
  63. dtype=dtype,
  64. )
  65. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  66. speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
  67. def _forward(
  68. data_path_and_name_and_type,
  69. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  70. output_dir_v2: Optional[str] = None,
  71. fs: dict = None,
  72. param_dict: dict = None
  73. ):
  74. # 3. Build data-iterator
  75. if data_path_and_name_and_type is None and raw_inputs is not None:
  76. if isinstance(raw_inputs, torch.Tensor):
  77. raw_inputs = raw_inputs.numpy()
  78. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  79. loader = build_streaming_iterator(
  80. task_name="vad",
  81. preprocess_args=None,
  82. data_path_and_name_and_type=data_path_and_name_and_type,
  83. dtype=dtype,
  84. batch_size=batch_size,
  85. key_file=key_file,
  86. num_workers=num_workers,
  87. )
  88. finish_count = 0
  89. file_count = 1
  90. # 7 .Start for-loop
  91. # FIXME(kamo): The output format should be discussed about
  92. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  93. if output_path is not None:
  94. writer = DatadirWriter(output_path)
  95. ibest_writer = writer[f"1best_recog"]
  96. else:
  97. writer = None
  98. ibest_writer = None
  99. vad_results = []
  100. for keys, batch in loader:
  101. assert isinstance(batch, dict), type(batch)
  102. assert all(isinstance(s, str) for s in keys), keys
  103. _bs = len(next(iter(batch.values())))
  104. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  105. # do vad segment
  106. _, results = speech2vadsegment(**batch)
  107. for i, _ in enumerate(keys):
  108. if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
  109. results[i] = json.dumps(results[i])
  110. item = {'key': keys[i], 'value': results[i]}
  111. vad_results.append(item)
  112. if writer is not None:
  113. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  114. return vad_results
  115. return _forward
  116. def inference_vad_online(
  117. batch_size: int,
  118. ngpu: int,
  119. log_level: Union[int, str],
  120. # data_path_and_name_and_type,
  121. vad_infer_config: Optional[str],
  122. vad_model_file: Optional[str],
  123. vad_cmvn_file: Optional[str] = None,
  124. # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  125. key_file: Optional[str] = None,
  126. allow_variable_data_keys: bool = False,
  127. output_dir: Optional[str] = None,
  128. dtype: str = "float32",
  129. seed: int = 0,
  130. num_workers: int = 1,
  131. **kwargs,
  132. ):
  133. assert check_argument_types()
  134. logging.basicConfig(
  135. level=log_level,
  136. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  137. )
  138. if ngpu >= 1 and torch.cuda.is_available():
  139. device = "cuda"
  140. else:
  141. device = "cpu"
  142. batch_size = 1
  143. # 1. Set random-seed
  144. set_all_random_seed(seed)
  145. # 2. Build speech2vadsegment
  146. speech2vadsegment_kwargs = dict(
  147. vad_infer_config=vad_infer_config,
  148. vad_model_file=vad_model_file,
  149. vad_cmvn_file=vad_cmvn_file,
  150. device=device,
  151. dtype=dtype,
  152. )
  153. logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
  154. speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
  155. def _forward(
  156. data_path_and_name_and_type,
  157. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  158. output_dir_v2: Optional[str] = None,
  159. fs: dict = None,
  160. param_dict: dict = None,
  161. ):
  162. # 3. Build data-iterator
  163. if data_path_and_name_and_type is None and raw_inputs is not None:
  164. if isinstance(raw_inputs, torch.Tensor):
  165. raw_inputs = raw_inputs.numpy()
  166. data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
  167. loader = build_streaming_iterator(
  168. task_name="vad",
  169. preprocess_args=None,
  170. data_path_and_name_and_type=data_path_and_name_and_type,
  171. dtype=dtype,
  172. batch_size=batch_size,
  173. key_file=key_file,
  174. num_workers=num_workers,
  175. )
  176. finish_count = 0
  177. file_count = 1
  178. # 7 .Start for-loop
  179. # FIXME(kamo): The output format should be discussed about
  180. output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
  181. if output_path is not None:
  182. writer = DatadirWriter(output_path)
  183. ibest_writer = writer[f"1best_recog"]
  184. else:
  185. writer = None
  186. ibest_writer = None
  187. vad_results = []
  188. if param_dict is None:
  189. param_dict = dict()
  190. param_dict['in_cache'] = dict()
  191. param_dict['is_final'] = True
  192. batch_in_cache = param_dict.get('in_cache', dict())
  193. is_final = param_dict.get('is_final', False)
  194. max_end_sil = param_dict.get('max_end_sil', 800)
  195. for keys, batch in loader:
  196. assert isinstance(batch, dict), type(batch)
  197. assert all(isinstance(s, str) for s in keys), keys
  198. _bs = len(next(iter(batch.values())))
  199. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  200. batch['in_cache'] = batch_in_cache
  201. batch['is_final'] = is_final
  202. batch['max_end_sil'] = max_end_sil
  203. # do vad segment
  204. _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
  205. # param_dict['in_cache'] = batch['in_cache']
  206. if results:
  207. for i, _ in enumerate(keys):
  208. if results[i]:
  209. if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
  210. results[i] = json.dumps(results[i])
  211. item = {'key': keys[i], 'value': results[i]}
  212. vad_results.append(item)
  213. if writer is not None:
  214. ibest_writer["text"][keys[i]] = "{}".format(results[i])
  215. return vad_results
  216. return _forward
  217. def inference_launch(mode, **kwargs):
  218. if mode == "offline":
  219. return inference_vad(**kwargs)
  220. elif mode == "online":
  221. return inference_vad_online(**kwargs)
  222. else:
  223. logging.info("Unknown decoding mode: {}".format(mode))
  224. return None
  225. def get_parser():
  226. parser = config_argparse.ArgumentParser(
  227. description="VAD Decoding",
  228. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  229. )
  230. # Note(kamo): Use '_' instead of '-' as separator.
  231. # '-' is confusing if written in yaml.
  232. parser.add_argument(
  233. "--log_level",
  234. type=lambda x: x.upper(),
  235. default="INFO",
  236. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  237. help="The verbose level of logging",
  238. )
  239. parser.add_argument("--output_dir", type=str, required=True)
  240. parser.add_argument(
  241. "--ngpu",
  242. type=int,
  243. default=0,
  244. help="The number of gpus. 0 indicates CPU mode",
  245. )
  246. parser.add_argument(
  247. "--njob",
  248. type=int,
  249. default=1,
  250. help="The number of jobs for each gpu",
  251. )
  252. parser.add_argument(
  253. "--gpuid_list",
  254. type=str,
  255. default="",
  256. help="The visible gpus",
  257. )
  258. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  259. parser.add_argument(
  260. "--dtype",
  261. default="float32",
  262. choices=["float16", "float32", "float64"],
  263. help="Data type",
  264. )
  265. parser.add_argument(
  266. "--num_workers",
  267. type=int,
  268. default=1,
  269. help="The number of workers used for DataLoader",
  270. )
  271. group = parser.add_argument_group("Input data related")
  272. group.add_argument(
  273. "--data_path_and_name_and_type",
  274. type=str2triple_str,
  275. required=True,
  276. action="append",
  277. )
  278. group.add_argument("--key_file", type=str_or_none)
  279. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  280. group = parser.add_argument_group("The model configuration related")
  281. group.add_argument(
  282. "--vad_infer_config",
  283. type=str,
  284. help="VAD infer configuration",
  285. )
  286. group.add_argument(
  287. "--vad_model_file",
  288. type=str,
  289. help="VAD model parameter file",
  290. )
  291. group.add_argument(
  292. "--vad_cmvn_file",
  293. type=str,
  294. help="Global CMVN file",
  295. )
  296. group.add_argument(
  297. "--vad_train_config",
  298. type=str,
  299. help="VAD training configuration",
  300. )
  301. group = parser.add_argument_group("The inference configuration related")
  302. group.add_argument(
  303. "--batch_size",
  304. type=int,
  305. default=1,
  306. help="The batch size for inference",
  307. )
  308. return parser
  309. def main(cmd=None):
  310. print(get_commandline_args(), file=sys.stderr)
  311. parser = get_parser()
  312. parser.add_argument(
  313. "--mode",
  314. type=str,
  315. default="vad",
  316. help="The decoding mode",
  317. )
  318. args = parser.parse_args(cmd)
  319. kwargs = vars(args)
  320. kwargs.pop("config", None)
  321. # set logging messages
  322. logging.basicConfig(
  323. level=args.log_level,
  324. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  325. )
  326. logging.info("Decoding args: {}".format(kwargs))
  327. # gpu setting
  328. if args.ngpu > 0:
  329. jobid = int(args.output_dir.split(".")[-1])
  330. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  331. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  332. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  333. inference_pipeline = inference_launch(**kwargs)
  334. return inference_pipeline(kwargs["data_path_and_name_and_type"])
  335. if __name__ == "__main__":
  336. main()