tp_inference.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. import argparse
  2. import logging
  3. from optparse import Option
  4. import sys
  5. import json
  6. from pathlib import Path
  7. from typing import Any
  8. from typing import List
  9. from typing import Optional
  10. from typing import Sequence
  11. from typing import Tuple
  12. from typing import Union
  13. from typing import Dict
  14. import numpy as np
  15. import torch
  16. from typeguard import check_argument_types
  17. from funasr.fileio.datadir_writer import DatadirWriter
  18. from funasr.datasets.preprocessor import LMPreprocessor
  19. from funasr.tasks.asr import ASRTaskAligner as ASRTask
  20. from funasr.torch_utils.device_funcs import to_device
  21. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  22. from funasr.utils import config_argparse
  23. from funasr.utils.cli_utils import get_commandline_args
  24. from funasr.utils.types import str2bool
  25. from funasr.utils.types import str2triple_str
  26. from funasr.utils.types import str_or_none
  27. from funasr.models.frontend.wav_frontend import WavFrontend
  28. from funasr.text.token_id_converter import TokenIDConverter
  29. from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
# ANSI escape sequences for highlighting console output.
header_colors = '\033[95m'
end_colors = '\033[0m'

# Default ASR language tag; module-global so callers can override it.
global_asr_language: str = 'zh-cn'
# Sample-rate configuration: rate of the incoming audio ('audio_fs') and the
# rate the model was trained at ('model_fs'), both in Hz.
global_sample_rate: Union[int, Dict[Any, int]] = {
    'audio_fs': 16000,
    'model_fs': 16000
}
class SpeechText2Timestamp:
    """Timestamp predictor: maps speech (plus known transcript lengths) to
    CIF alphas/peaks from which token-level timestamps are derived.

    Loads an ASR-aligner model from config/checkpoint files and, when the
    training config declares a frontend, builds a ``WavFrontend`` so raw
    waveforms can be turned into features before encoding.
    """

    def __init__(
            self,
            timestamp_infer_config: Union[Path, str] = None,
            timestamp_model_file: Union[Path, str] = None,
            timestamp_cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
            dtype: str = "float32",
            **kwargs,
    ):
        assert check_argument_types()
        # 1. Build ASR model
        tp_model, tp_train_args = ASRTask.build_model_from_file(
            timestamp_infer_config, timestamp_model_file, device=device
        )
        if 'cuda' in device:
            tp_model = tp_model.cuda()  # force model to cuda
        frontend = None
        if tp_train_args.frontend is not None:
            # Feature extraction (fbank + CMVN) done here, outside the model.
            frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
        logging.info("tp_model: {}".format(tp_model))
        logging.info("tp_train_args: {}".format(tp_train_args))
        tp_model.to(dtype=getattr(torch, dtype)).eval()
        logging.info(f"Decoding device={device}, dtype={dtype}")
        self.tp_model = tp_model
        self.tp_train_args = tp_train_args
        token_list = self.tp_model.token_list
        # Converter used by callers to map predicted token ids back to tokens.
        self.converter = TokenIDConverter(token_list=token_list)
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        # A conv2d input layer downsamples the time axis by 4; otherwise the
        # encoder keeps the input frame rate.
        self.encoder_downsampling_factor = 1
        if tp_train_args.encoder_conf["input_layer"] == "conv2d":
            self.encoder_downsampling_factor = 4

    @torch.no_grad()
    def __call__(
            self,
            speech: Union[torch.Tensor, np.ndarray],
            speech_lengths: Union[torch.Tensor, np.ndarray] = None,
            text_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Run encoder + predictor and return (us_alphas, us_peaks).

        ``speech`` is either raw waveform (when a frontend is configured) or
        precomputed features; ``text_lengths`` must be provided — it is moved
        to the model device and used by the predictor.
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # Features are already extracted here; drop the model's own
            # frontend so it is not applied a second time.
            self.tp_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        # lfr_factor = max(1, (feats.size()[-1]//80)-1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, enc_len = self.tp_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        # c. Forward Predictor
        # text_lengths+1: presumably accounts for an appended terminator token
        # expected by the predictor — TODO confirm against the aligner model.
        _, _, us_alphas, us_peaks = self.tp_model.calc_predictor_timestamp(enc, enc_len, text_lengths.to(self.device)+1)
        return us_alphas, us_peaks
  101. def inference(
  102. batch_size: int,
  103. ngpu: int,
  104. log_level: Union[int, str],
  105. data_path_and_name_and_type,
  106. timestamp_infer_config: Optional[str],
  107. timestamp_model_file: Optional[str],
  108. timestamp_cmvn_file: Optional[str] = None,
  109. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  110. key_file: Optional[str] = None,
  111. allow_variable_data_keys: bool = False,
  112. output_dir: Optional[str] = None,
  113. dtype: str = "float32",
  114. seed: int = 0,
  115. num_workers: int = 1,
  116. split_with_space: bool = True,
  117. seg_dict_file: Optional[str] = None,
  118. **kwargs,
  119. ):
  120. inference_pipeline = inference_modelscope(
  121. batch_size=batch_size,
  122. ngpu=ngpu,
  123. log_level=log_level,
  124. timestamp_infer_config=timestamp_infer_config,
  125. timestamp_model_file=timestamp_model_file,
  126. timestamp_cmvn_file=timestamp_cmvn_file,
  127. key_file=key_file,
  128. allow_variable_data_keys=allow_variable_data_keys,
  129. output_dir=output_dir,
  130. dtype=dtype,
  131. seed=seed,
  132. num_workers=num_workers,
  133. split_with_space=split_with_space,
  134. seg_dict_file=seg_dict_file,
  135. **kwargs,
  136. )
  137. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
        batch_size: int,
        ngpu: int,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        timestamp_infer_config: Optional[str],
        timestamp_model_file: Optional[str],
        timestamp_cmvn_file: Optional[str] = None,
        # raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        key_file: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        num_workers: int = 1,
        split_with_space: bool = True,
        seg_dict_file: Optional[str] = None,
        **kwargs,
):
    """Build and return a forward function for timestamp prediction.

    Sets up logging, the model (``SpeechText2Timestamp``), the text
    preprocessor and optional output writers, then returns ``_forward``,
    which iterates a data loader and yields per-utterance timestamp results.

    Raises:
        NotImplementedError: if ``batch_size > 1`` or ``ngpu > 1``.
    """
    assert check_argument_types()
    # Limit intra-op CPU threads; 'ncpu' may be passed through **kwargs.
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2vadsegment
    speechtext2timestamp_kwargs = dict(
        timestamp_infer_config=timestamp_infer_config,
        timestamp_model_file=timestamp_model_file,
        timestamp_cmvn_file=timestamp_cmvn_file,
        device=device,
        dtype=dtype,
    )
    logging.info("speechtext2timestamp_kwargs: {}".format(speechtext2timestamp_kwargs))
    speechtext2timestamp = SpeechText2Timestamp(**speechtext2timestamp_kwargs)
    # Text preprocessor: tokenizes the transcript field ("text") so that
    # text lengths are available to the predictor.
    preprocessor = LMPreprocessor(
        train=False,
        token_type=speechtext2timestamp.tp_train_args.token_type,
        token_list=speechtext2timestamp.tp_train_args.token_list,
        bpemodel=None,
        text_cleaner=None,
        g2p_type=None,
        text_name="text",
        non_linguistic_symbols=speechtext2timestamp.tp_train_args.non_linguistic_symbols,
        split_with_space=split_with_space,
        seg_dict_file=seg_dict_file,
    )
    # NOTE(review): this writer is shadowed inside _forward, which rebuilds
    # its own from output_dir_v2/output_dir; kept as-is (DatadirWriter
    # construction may create the output directory as a side effect).
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        tp_writer = writer[f"timestamp_prediction"]
        # ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
    else:
        tp_writer = None

    def _forward(
            data_path_and_name_and_type,
            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
            output_dir_v2: Optional[str] = None,
            fs: dict = None,
            param_dict: dict = None,
            **kwargs
    ):
        # Per-call output dir overrides the one given at build time.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
            tp_writer = writer[f"timestamp_prediction"]
        else:
            tp_writer = None
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=preprocessor,
            collate_fn=ASRTask.build_collate_fn(speechtext2timestamp.tp_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        tp_result_list = []
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            logging.info("timestamp predicting, utt_id: {}".format(keys))
            # Only the fields the predictor consumes; the raw 'text' ids are
            # used below for token conversion.
            _batch = {'speech': batch['speech'],
                      'speech_lengths': batch['speech_lengths'],
                      'text_lengths': batch['text_lengths']}
            us_alphas, us_cif_peak = speechtext2timestamp(**_batch)
            for batch_id in range(_bs):
                key = keys[batch_id]
                token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
                ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
                logging.warning(ts_str)
                item = {'key': key, 'value': ts_str, 'timestamp': ts_list}
                if tp_writer is not None:
                    tp_writer["tp_sync"][key+'#'] = ts_str
                    tp_writer["tp_time"][key+'#'] = str(ts_list)
                tp_result_list.append(item)
        return tp_result_list
    return _forward
  256. def get_parser():
  257. parser = config_argparse.ArgumentParser(
  258. description="Timestamp Prediction Inference",
  259. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  260. )
  261. # Note(kamo): Use '_' instead of '-' as separator.
  262. # '-' is confusing if written in yaml.
  263. parser.add_argument(
  264. "--log_level",
  265. type=lambda x: x.upper(),
  266. default="INFO",
  267. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  268. help="The verbose level of logging",
  269. )
  270. parser.add_argument("--output_dir", type=str, required=False)
  271. parser.add_argument(
  272. "--ngpu",
  273. type=int,
  274. default=0,
  275. help="The number of gpus. 0 indicates CPU mode",
  276. )
  277. parser.add_argument(
  278. "--gpuid_list",
  279. type=str,
  280. default="",
  281. help="The visible gpus",
  282. )
  283. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  284. parser.add_argument(
  285. "--dtype",
  286. default="float32",
  287. choices=["float16", "float32", "float64"],
  288. help="Data type",
  289. )
  290. parser.add_argument(
  291. "--num_workers",
  292. type=int,
  293. default=0,
  294. help="The number of workers used for DataLoader",
  295. )
  296. group = parser.add_argument_group("Input data related")
  297. group.add_argument(
  298. "--data_path_and_name_and_type",
  299. type=str2triple_str,
  300. required=False,
  301. action="append",
  302. )
  303. group.add_argument("--raw_inputs", type=list, default=None)
  304. # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
  305. group.add_argument("--key_file", type=str_or_none)
  306. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  307. group = parser.add_argument_group("The model configuration related")
  308. group.add_argument(
  309. "--timestamp_infer_config",
  310. type=str,
  311. help="VAD infer configuration",
  312. )
  313. group.add_argument(
  314. "--timestamp_model_file",
  315. type=str,
  316. help="VAD model parameter file",
  317. )
  318. group.add_argument(
  319. "--timestamp_cmvn_file",
  320. type=str,
  321. help="Global cmvn file",
  322. )
  323. group = parser.add_argument_group("infer related")
  324. group.add_argument(
  325. "--batch_size",
  326. type=int,
  327. default=1,
  328. help="The batch size for inference",
  329. )
  330. group.add_argument(
  331. "--seg_dict_file",
  332. type=str,
  333. default=None,
  334. help="The batch size for inference",
  335. )
  336. group.add_argument(
  337. "--split_with_space",
  338. type=bool,
  339. default=False,
  340. help="The batch size for inference",
  341. )
  342. return parser
  343. def main(cmd=None):
  344. print(get_commandline_args(), file=sys.stderr)
  345. parser = get_parser()
  346. args = parser.parse_args(cmd)
  347. kwargs = vars(args)
  348. kwargs.pop("config", None)
  349. inference(**kwargs)
  350. if __name__ == "__main__":
  351. main()