vad_inference_launch.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch

# Limit PyTorch intra-op parallelism early, before any model is built.
torch.set_num_threads(1)

import argparse
import json
import logging
import os
import sys
from typing import Optional, Union

import numpy as np
from typeguard import check_argument_types

from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.tasks.vad import VADTask
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool, str2triple_str, str_or_none


def inference_vad(
    batch_size: int,
    ngpu: int,
    log_level: Union[int, str],
    vad_infer_config: Optional[str],
    vad_model_file: Optional[str],
    vad_cmvn_file: Optional[str] = None,
    key_file: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    num_workers: int = 1,
    **kwargs,
):
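    """Build an offline VAD inference pipeline.

    Returns a `_forward` callable that reads speech from
    `data_path_and_name_and_type` (or raw waveforms), runs the
    Speech2VadSegment model, and returns per-key segment results.
    """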
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        batch_size = 1

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Optional[Union[np.ndarray, torch.Tensor]] = None,
        output_dir_v2: Optional[str] = None,
        fs: Optional[dict] = None,
        param_dict: Optional[dict] = None,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 4. Start the decoding loop
        # FIXME(kamo): The output format should be discussed
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer["1best_recog"]
        else:
            writer = None
            ibest_writer = None

        vad_results = []
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            # do vad segment
            _, results = speech2vadsegment(**batch)
            for i, _ in enumerate(keys):
                if os.environ.get("MODELSCOPE_ENVIRONMENT") == "eas":
                    results[i] = json.dumps(results[i])
                item = {"key": keys[i], "value": results[i]}
                vad_results.append(item)
                if writer is not None:
                    ibest_writer["text"][keys[i]] = "{}".format(results[i])

        return vad_results

    return _forward


def inference_vad_online(
    batch_size: int,
    ngpu: int,
    log_level: Union[int, str],
    vad_infer_config: Optional[str],
    vad_model_file: Optional[str],
    vad_cmvn_file: Optional[str] = None,
    key_file: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    num_workers: int = 1,
    **kwargs,
):
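    """Build a streaming (online) VAD inference pipeline.

    Same interface as `inference_vad`, but the returned callable keeps
    decoder state between calls via `param_dict['in_cache']`, so audio
    can be fed chunk by chunk.
    """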
    assert check_argument_types()
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        batch_size = 1

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Optional[Union[np.ndarray, torch.Tensor]] = None,
        output_dir_v2: Optional[str] = None,
        fs: Optional[dict] = None,
        param_dict: Optional[dict] = None,
    ):
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = VADTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        # 4. Start the decoding loop
        # FIXME(kamo): The output format should be discussed
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer["1best_recog"]
        else:
            writer = None
            ibest_writer = None

        vad_results = []
        # Streaming state: `in_cache` carries model state across chunk calls,
        # `is_final` marks the last chunk of the stream, and `max_end_sil` is
        # the trailing-silence threshold (ms) that closes a segment.
        if param_dict is None:
            param_dict = dict()
            param_dict["in_cache"] = dict()
            param_dict["is_final"] = True
        batch_in_cache = param_dict.get("in_cache", dict())
        is_final = param_dict.get("is_final", False)
        max_end_sil = param_dict.get("max_end_sil", 800)

        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            batch["in_cache"] = batch_in_cache
            batch["is_final"] = is_final
            batch["max_end_sil"] = max_end_sil

            # do vad segment; the updated cache is stored back for the next call
            _, results, param_dict["in_cache"] = speech2vadsegment(**batch)
            if results:
                for i, _ in enumerate(keys):
                    if not results[i]:
                        continue
                    if os.environ.get("MODELSCOPE_ENVIRONMENT") == "eas":
                        results[i] = json.dumps(results[i])
                    item = {"key": keys[i], "value": results[i]}
                    vad_results.append(item)
                    if writer is not None:
                        ibest_writer["text"][keys[i]] = "{}".format(results[i])

        return vad_results

    return _forward


def inference_launch(mode, **kwargs):
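    """Dispatch to the offline or online VAD inference pipeline by `mode`."""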
    if mode == "offline":
        return inference_vad(**kwargs)
    elif mode == "online":
        return inference_vad_online(**kwargs)
    else:
        raise ValueError("Unknown decoding mode: {}".format(mode))


def get_parser():
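    """Build the command-line argument parser for VAD decoding."""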
    parser = config_argparse.ArgumentParser(
        description="VAD Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--vad_cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--vad_train_config",
        type=str,
        help="VAD training configuration",
    )

    group = parser.add_argument_group("The inference configuration related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    return parser


def main(cmd=None):
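    """Parse arguments, bind a GPU if requested, and run VAD decoding."""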
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="offline",
        help="The decoding mode (offline or online)",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)

    # set logging messages
    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info("Decoding args: {}".format(kwargs))

    # gpu setting: pick one GPU from gpuid_list using the job id encoded
    # as the numeric suffix of output_dir (e.g. ".../output.3")
    if args.ngpu > 0:
        jobid = int(args.output_dir.split(".")[-1])
        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid

    inference_pipeline = inference_launch(**kwargs)
    return inference_pipeline(kwargs["data_path_and_name_and_type"])


if __name__ == "__main__":
    main()
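
# A minimal example invocation as a sketch; the file names and paths below
# are illustrative, not shipped with this script:
#
#   python vad_inference_launch.py \
#       --mode offline \
#       --data_path_and_name_and_type wav.scp,speech,sound \
#       --vad_infer_config vad.yaml \
#       --vad_model_file vad.pb \
#       --vad_cmvn_file vad.mvn \
#       --output_dir exp/vad_results/output.1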