# asr_inference_paraformer_vad_punc.py
  1. #!/usr/bin/env python3
  2. import json
  3. import argparse
  4. import logging
  5. import sys
  6. import time
  7. from pathlib import Path
  8. from typing import Optional
  9. from typing import Sequence
  10. from typing import Tuple
  11. from typing import Union
  12. from typing import Dict
  13. from typing import Any
  14. from typing import List
  15. import math
  16. import copy
  17. import numpy as np
  18. import torch
  19. from typeguard import check_argument_types
  20. from funasr.fileio.datadir_writer import DatadirWriter
  21. from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
  22. from funasr.modules.beam_search.beam_search import Hypothesis
  23. from funasr.modules.scorers.ctc import CTCPrefixScorer
  24. from funasr.modules.scorers.length_bonus import LengthBonus
  25. from funasr.modules.subsampling import TooShortUttError
  26. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  27. from funasr.tasks.lm import LMTask
  28. from funasr.text.build_tokenizer import build_tokenizer
  29. from funasr.text.token_id_converter import TokenIDConverter
  30. from funasr.torch_utils.device_funcs import to_device
  31. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  32. from funasr.utils import config_argparse
  33. from funasr.utils.cli_utils import get_commandline_args
  34. from funasr.utils.types import str2bool
  35. from funasr.utils.types import str2triple_str
  36. from funasr.utils.types import str_or_none
  37. from funasr.utils import asr_utils, wav_utils, postprocess_utils
  38. from funasr.models.frontend.wav_frontend import WavFrontend
  39. from funasr.tasks.vad import VADTask
  40. from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
  41. from funasr.bin.punctuation_infer import Text2Punc
  42. from funasr.models.e2e_asr_paraformer import BiCifParaformer
# ANSI escape sequences for optional colored terminal output.
header_colors = '\033[95m'
end_colors = '\033[0m'
  45. class Speech2Text:
  46. """Speech2Text class
  47. Examples:
  48. >>> import soundfile
  49. >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
  50. >>> audio, rate = soundfile.read("speech.wav")
  51. >>> speech2text(audio)
  52. [(text, token, token_int, hypothesis object), ...]
  53. """
  54. def __init__(
  55. self,
  56. asr_train_config: Union[Path, str] = None,
  57. asr_model_file: Union[Path, str] = None,
  58. cmvn_file: Union[Path, str] = None,
  59. lm_train_config: Union[Path, str] = None,
  60. lm_file: Union[Path, str] = None,
  61. token_type: str = None,
  62. bpemodel: str = None,
  63. device: str = "cpu",
  64. maxlenratio: float = 0.0,
  65. minlenratio: float = 0.0,
  66. dtype: str = "float32",
  67. beam_size: int = 20,
  68. ctc_weight: float = 0.5,
  69. lm_weight: float = 1.0,
  70. ngram_weight: float = 0.9,
  71. penalty: float = 0.0,
  72. nbest: int = 1,
  73. frontend_conf: dict = None,
  74. **kwargs,
  75. ):
  76. assert check_argument_types()
  77. # 1. Build ASR model
  78. scorers = {}
  79. asr_model, asr_train_args = ASRTask.build_model_from_file(
  80. asr_train_config, asr_model_file, cmvn_file=cmvn_file, device=device
  81. )
  82. frontend = None
  83. if asr_model.frontend is not None and asr_train_args.frontend_conf is not None:
  84. frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
  85. # logging.info("asr_model: {}".format(asr_model))
  86. # logging.info("asr_train_args: {}".format(asr_train_args))
  87. asr_model.to(dtype=getattr(torch, dtype)).eval()
  88. if asr_model.ctc != None:
  89. ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
  90. scorers.update(
  91. ctc=ctc
  92. )
  93. token_list = asr_model.token_list
  94. scorers.update(
  95. length_bonus=LengthBonus(len(token_list)),
  96. )
  97. # 2. Build Language model
  98. if lm_train_config is not None:
  99. lm, lm_train_args = LMTask.build_model_from_file(
  100. lm_train_config, lm_file, device
  101. )
  102. scorers["lm"] = lm.lm
  103. # 3. Build ngram model
  104. # ngram is not supported now
  105. ngram = None
  106. scorers["ngram"] = ngram
  107. # 4. Build BeamSearch object
  108. # transducer is not supported now
  109. beam_search_transducer = None
  110. weights = dict(
  111. decoder=1.0 - ctc_weight,
  112. ctc=ctc_weight,
  113. lm=lm_weight,
  114. ngram=ngram_weight,
  115. length_bonus=penalty,
  116. )
  117. beam_search = BeamSearch(
  118. beam_size=beam_size,
  119. weights=weights,
  120. scorers=scorers,
  121. sos=asr_model.sos,
  122. eos=asr_model.eos,
  123. vocab_size=len(token_list),
  124. token_list=token_list,
  125. pre_beam_score_key=None if ctc_weight == 1.0 else "full",
  126. )
  127. beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
  128. for scorer in scorers.values():
  129. if isinstance(scorer, torch.nn.Module):
  130. scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
  131. logging.info(f"Decoding device={device}, dtype={dtype}")
  132. # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
  133. if token_type is None:
  134. token_type = asr_train_args.token_type
  135. if bpemodel is None:
  136. bpemodel = asr_train_args.bpemodel
  137. if token_type is None:
  138. tokenizer = None
  139. elif token_type == "bpe":
  140. if bpemodel is not None:
  141. tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
  142. else:
  143. tokenizer = None
  144. else:
  145. tokenizer = build_tokenizer(token_type=token_type)
  146. converter = TokenIDConverter(token_list=token_list)
  147. logging.info(f"Text tokenizer: {tokenizer}")
  148. self.asr_model = asr_model
  149. self.asr_train_args = asr_train_args
  150. self.converter = converter
  151. self.tokenizer = tokenizer
  152. is_use_lm = lm_weight != 0.0 and lm_file is not None
  153. if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
  154. beam_search = None
  155. self.beam_search = beam_search
  156. logging.info(f"Beam_search: {self.beam_search}")
  157. self.beam_search_transducer = beam_search_transducer
  158. self.maxlenratio = maxlenratio
  159. self.minlenratio = minlenratio
  160. self.device = device
  161. self.dtype = dtype
  162. self.nbest = nbest
  163. self.frontend = frontend
  164. self.encoder_downsampling_factor = 1
  165. if asr_train_args.encoder_conf["input_layer"] == "conv2d":
  166. self.encoder_downsampling_factor = 4
  167. @torch.no_grad()
  168. def __call__(
  169. self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
  170. begin_time: int = 0, end_time: int = None,
  171. ):
  172. """Inference
  173. Args:
  174. speech: Input speech data
  175. Returns:
  176. text, token, token_int, hyp
  177. """
  178. assert check_argument_types()
  179. # Input as audio signal
  180. if isinstance(speech, np.ndarray):
  181. speech = torch.tensor(speech)
  182. if self.frontend is not None:
  183. # feats, feats_len = self.frontend.forward(speech, speech_lengths)
  184. # fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
  185. feats, feats_len = self.frontend.forward_lfr_cmvn(speech, speech_lengths)
  186. feats = to_device(feats, device=self.device)
  187. feats_len = feats_len.int()
  188. self.asr_model.frontend = None
  189. else:
  190. feats = speech
  191. feats_len = speech_lengths
  192. lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
  193. batch = {"speech": feats, "speech_lengths": feats_len}
  194. # a. To device
  195. batch = to_device(batch, device=self.device)
  196. # b. Forward Encoder
  197. enc, enc_len = self.asr_model.encode(**batch)
  198. if isinstance(enc, tuple):
  199. enc = enc[0]
  200. # assert len(enc) == 1, len(enc)
  201. enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
  202. predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
  203. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
  204. predictor_outs[2], predictor_outs[3]
  205. pre_token_length = pre_token_length.round().long()
  206. if torch.max(pre_token_length) < 1:
  207. return []
  208. decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
  209. decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
  210. if isinstance(self.asr_model, BiCifParaformer):
  211. _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
  212. pre_token_length) # test no bias cif2
  213. results = []
  214. b, n, d = decoder_out.size()
  215. for i in range(b):
  216. x = enc[i, :enc_len[i], :]
  217. am_scores = decoder_out[i, :pre_token_length[i], :]
  218. if self.beam_search is not None:
  219. nbest_hyps = self.beam_search(
  220. x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
  221. )
  222. nbest_hyps = nbest_hyps[: self.nbest]
  223. else:
  224. yseq = am_scores.argmax(dim=-1)
  225. score = am_scores.max(dim=-1)[0]
  226. score = torch.sum(score, dim=-1)
  227. # pad with mask tokens to ensure compatibility with sos/eos tokens
  228. yseq = torch.tensor(
  229. [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
  230. )
  231. nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
  232. for hyp in nbest_hyps:
  233. assert isinstance(hyp, (Hypothesis)), type(hyp)
  234. # remove sos/eos and get results
  235. last_pos = -1
  236. if isinstance(hyp.yseq, list):
  237. token_int = hyp.yseq[1:last_pos]
  238. else:
  239. token_int = hyp.yseq[1:last_pos].tolist()
  240. # remove blank symbol id, which is assumed to be 0
  241. token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
  242. # Change integer-ids to tokens
  243. token = self.converter.ids2tokens(token_int)
  244. if self.tokenizer is not None:
  245. text = self.tokenizer.tokens2text(token)
  246. else:
  247. text = None
  248. timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
  249. results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
  250. # assert check_return_type(results)
  251. return results
  252. class Speech2VadSegment:
  253. """Speech2VadSegment class
  254. Examples:
  255. >>> import soundfile
  256. >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
  257. >>> audio, rate = soundfile.read("speech.wav")
  258. >>> speech2segment(audio)
  259. [[10, 230], [245, 450], ...]
  260. """
  261. def __init__(
  262. self,
  263. vad_infer_config: Union[Path, str] = None,
  264. vad_model_file: Union[Path, str] = None,
  265. vad_cmvn_file: Union[Path, str] = None,
  266. device: str = "cpu",
  267. batch_size: int = 1,
  268. dtype: str = "float32",
  269. **kwargs,
  270. ):
  271. assert check_argument_types()
  272. # 1. Build vad model
  273. vad_model, vad_infer_args = VADTask.build_model_from_file(
  274. vad_infer_config, vad_model_file, device
  275. )
  276. frontend = None
  277. if vad_infer_args.frontend is not None:
  278. frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
  279. # logging.info("vad_model: {}".format(vad_model))
  280. # logging.info("vad_infer_args: {}".format(vad_infer_args))
  281. vad_model.to(dtype=getattr(torch, dtype)).eval()
  282. self.vad_model = vad_model
  283. self.vad_infer_args = vad_infer_args
  284. self.device = device
  285. self.dtype = dtype
  286. self.frontend = frontend
  287. self.batch_size = batch_size
  288. @torch.no_grad()
  289. def __call__(
  290. self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
  291. ) -> List[List[int]]:
  292. """Inference
  293. Args:
  294. speech: Input speech data
  295. Returns:
  296. text, token, token_int, hyp
  297. """
  298. assert check_argument_types()
  299. # Input as audio signal
  300. if isinstance(speech, np.ndarray):
  301. speech = torch.tensor(speech)
  302. if self.frontend is not None:
  303. self.frontend.filter_length_max = math.inf
  304. fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
  305. feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
  306. fbanks = to_device(fbanks, device=self.device)
  307. feats = to_device(feats, device=self.device)
  308. feats_len = feats_len.int()
  309. else:
  310. raise Exception("Need to extract feats first, please configure frontend configuration")
  311. # b. Forward Encoder streaming
  312. t_offset = 0
  313. step = min(feats_len, 6000)
  314. segments = [[]] * self.batch_size
  315. for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
  316. if t_offset + step >= feats_len - 1:
  317. step = feats_len - t_offset
  318. is_final_send = True
  319. else:
  320. is_final_send = False
  321. batch = {
  322. "feats": feats[:, t_offset:t_offset + step, :],
  323. "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
  324. "is_final_send": is_final_send
  325. }
  326. # a. To device
  327. batch = to_device(batch, device=self.device)
  328. segments_part = self.vad_model(**batch)
  329. if segments_part:
  330. for batch_num in range(0, self.batch_size):
  331. segments[batch_num] += segments_part[batch_num]
  332. return fbanks, segments
  333. def inference(
  334. maxlenratio: float,
  335. minlenratio: float,
  336. batch_size: int,
  337. beam_size: int,
  338. ngpu: int,
  339. ctc_weight: float,
  340. lm_weight: float,
  341. penalty: float,
  342. log_level: Union[int, str],
  343. data_path_and_name_and_type,
  344. asr_train_config: Optional[str],
  345. asr_model_file: Optional[str],
  346. cmvn_file: Optional[str] = None,
  347. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  348. lm_train_config: Optional[str] = None,
  349. lm_file: Optional[str] = None,
  350. token_type: Optional[str] = None,
  351. key_file: Optional[str] = None,
  352. word_lm_train_config: Optional[str] = None,
  353. bpemodel: Optional[str] = None,
  354. allow_variable_data_keys: bool = False,
  355. streaming: bool = False,
  356. output_dir: Optional[str] = None,
  357. dtype: str = "float32",
  358. seed: int = 0,
  359. ngram_weight: float = 0.9,
  360. nbest: int = 1,
  361. num_workers: int = 1,
  362. vad_infer_config: Optional[str] = None,
  363. vad_model_file: Optional[str] = None,
  364. vad_cmvn_file: Optional[str] = None,
  365. time_stamp_writer: bool = False,
  366. punc_infer_config: Optional[str] = None,
  367. punc_model_file: Optional[str] = None,
  368. **kwargs,
  369. ):
  370. inference_pipeline = inference_modelscope(
  371. maxlenratio=maxlenratio,
  372. minlenratio=minlenratio,
  373. batch_size=batch_size,
  374. beam_size=beam_size,
  375. ngpu=ngpu,
  376. ctc_weight=ctc_weight,
  377. lm_weight=lm_weight,
  378. penalty=penalty,
  379. log_level=log_level,
  380. asr_train_config=asr_train_config,
  381. asr_model_file=asr_model_file,
  382. cmvn_file=cmvn_file,
  383. raw_inputs=raw_inputs,
  384. lm_train_config=lm_train_config,
  385. lm_file=lm_file,
  386. token_type=token_type,
  387. key_file=key_file,
  388. word_lm_train_config=word_lm_train_config,
  389. bpemodel=bpemodel,
  390. allow_variable_data_keys=allow_variable_data_keys,
  391. streaming=streaming,
  392. output_dir=output_dir,
  393. dtype=dtype,
  394. seed=seed,
  395. ngram_weight=ngram_weight,
  396. nbest=nbest,
  397. num_workers=num_workers,
  398. vad_infer_config=vad_infer_config,
  399. vad_model_file=vad_model_file,
  400. vad_cmvn_file=vad_cmvn_file,
  401. time_stamp_writer=time_stamp_writer,
  402. punc_infer_config=punc_infer_config,
  403. punc_model_file=punc_model_file,
  404. **kwargs,
  405. )
  406. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    vad_infer_config: Optional[str] = None,
    vad_model_file: Optional[str] = None,
    vad_cmvn_file: Optional[str] = None,
    time_stamp_writer: bool = True,
    punc_infer_config: Optional[str] = None,
    punc_model_file: Optional[str] = None,
    outputs_dict: Optional[bool] = True,
    param_dict: dict = None,
    **kwargs,
):
    """Build the VAD + ASR + (optional) punctuation pipeline once and return
    a reusable ``_forward`` closure that decodes data on each call.

    The closure takes ``(data_path_and_name_and_type, raw_inputs, ...)`` and
    returns a list of result dicts (``key``/``value`` plus optional
    ``text_postprocessed``/``time_stamp`` fields).
    """
    assert check_argument_types()
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)

    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
    )
    speech2text = Speech2Text(**speech2text_kwargs)

    # Punctuation restoration is optional; only built when a model is given.
    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)

    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 ):
        # Decode all utterances reachable through the data iterator (or the
        # given raw waveform) and return the accumulated result list.
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True

        finish_count = 0
        file_count = 1
        lfr_factor = 6
        # 7 .Start for-loop
        asr_result_list = []
        # A per-call output dir (output_dir_v2) takes precedence over the
        # factory-level output_dir.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer[f"1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"

            # Run VAD first, then decode each detected segment with ASR.
            vad_results = speech2vadsegment(**batch)
            fbanks, vadsegments = vad_results[0], vad_results[1]
            for i, segments in enumerate(vadsegments):
                result_segments = [["", [], [], []]]
                for j, segment_idx in enumerate(segments):
                    # Segment boundaries are in ms; fbank frames are 10 ms each.
                    bed_idx, end_idx = int(segment_idx[0] / 10), int(segment_idx[1] / 10)
                    segment = fbanks[:, bed_idx:end_idx, :].to(device)
                    speech_lengths = torch.Tensor([end_idx - bed_idx]).int().to(device)
                    batch = {"speech": segment, "speech_lengths": speech_lengths, "begin_time": vadsegments[i][j][0],
                             "end_time": vadsegments[i][j][1]}
                    results = speech2text(**batch)
                    if len(results) < 1:
                        continue
                    # Keep (text, token, token_int, timestamp); drop the
                    # trailing (enc_len_batch_total, lfr_factor) pair.
                    result_cur = [results[0][:-2]]
                    if j == 0:
                        result_segments = result_cur
                    else:
                        # Concatenate each field of this segment onto the
                        # accumulated utterance-level result.
                        result_segments = [
                            [result_segments[0][i] + result_cur[0][i] for i in range(len(result_cur[0]))]]

                key = keys[0]
                result = result_segments[0]
                text, token, token_int = result[0], result[1], result[2]
                time_stamp = None if len(result) < 4 else result[3]
                if use_timestamp and time_stamp is not None:
                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
                else:
                    postprocessed_result = postprocess_utils.sentence_postprocess(token)
                text_postprocessed = ""
                time_stamp_postprocessed = ""
                text_postprocessed_punc = postprocessed_result
                # sentence_postprocess returns 3 items when timestamps were
                # supplied, otherwise 2.
                if len(postprocessed_result) == 3:
                    text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                        postprocessed_result[1], \
                        postprocessed_result[2]
                else:
                    text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]

                text_postprocessed_punc = text_postprocessed
                if len(word_lists) > 0 and text2punc is not None:
                    text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                item = {'key': key, 'value': text_postprocessed_punc}
                if text_postprocessed != "":
                    item['text_postprocessed'] = text_postprocessed
                if time_stamp_postprocessed != "":
                    item['time_stamp'] = time_stamp_postprocessed
                asr_result_list.append(item)
                finish_count += 1
                # asr_utils.print_progress(finish_count / file_count)
                if writer is not None:
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["vad"][key] = "{}".format(vadsegments)
                    ibest_writer["text"][key] = text_postprocessed
                    ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                    if time_stamp_postprocessed is not None:
                        ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
                logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        return asr_result_list

    return _forward
def get_parser():
    """Build the command-line argument parser for ASR decoding.

    Returns:
        config_argparse.ArgumentParser: parser whose parsed namespace maps
        directly onto the keyword arguments of ``inference``.
    """
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global cmvn file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0 (default), it uses a end-detect "
        "function "
        "to automatically find maximum hypothesis lengths."
        "If maxlenratio<0.0, its absolute value is interpreted"
        "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument("--time_stamp_writer", type=str2bool, default=False)
    group.add_argument(
        "--frontend_conf",
        default=None,
        help="",
    )
    group.add_argument("--raw_inputs", type=list, default=None)
    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])

    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
        "If not given, refers from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
        "If not given, refers from the training args",
    )
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--vad_cmvn_file",
        type=str,
        help="vad, Global cmvn file",
    )
    group.add_argument(
        "--punc_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--punc_model_file",
        type=str,
        help="VAD model parameter file",
    )
    return parser
  774. def main(cmd=None):
  775. print(get_commandline_args(), file=sys.stderr)
  776. parser = get_parser()
  777. args = parser.parse_args(cmd)
  778. kwargs = vars(args)
  779. kwargs.pop("config", None)
  780. inference(**kwargs)
  781. if __name__ == "__main__":
  782. main()