# asr_inference_paraformer.py
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import sys
  5. import time
  6. from pathlib import Path
  7. from typing import Optional
  8. from typing import Sequence
  9. from typing import Tuple
  10. from typing import Union
  11. import numpy as np
  12. import torch
  13. from typeguard import check_argument_types
  14. from funasr.fileio.datadir_writer import DatadirWriter
  15. from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
  16. from funasr.modules.beam_search.beam_search import Hypothesis
  17. from funasr.modules.scorers.ctc import CTCPrefixScorer
  18. from funasr.modules.scorers.length_bonus import LengthBonus
  19. from funasr.modules.subsampling import TooShortUttError
  20. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  21. from funasr.tasks.lm import LMTask
  22. from funasr.text.build_tokenizer import build_tokenizer
  23. from funasr.text.token_id_converter import TokenIDConverter
  24. from funasr.torch_utils.device_funcs import to_device
  25. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  26. from funasr.utils import config_argparse
  27. from funasr.utils.cli_utils import get_commandline_args
  28. from funasr.utils.types import str2bool
  29. from funasr.utils.types import str2triple_str
  30. from funasr.utils.types import str_or_none
class Speech2Text:
    """Paraformer speech-to-text inference wrapper.

    Loads a trained Paraformer ASR model (and optionally a neural LM),
    builds the beam-search decoder and the id->token converter, and exposes
    a callable that maps one utterance to an n-best list of hypotheses.

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        **kwargs,
    ):
        """Build the model, scorers, beam search, and text converter.

        Args:
            asr_train_config: Path to the ASR training config YAML.
            asr_model_file: Path to the trained ASR model parameters.
            lm_train_config: Optional LM training config; no LM if None.
            lm_file: Optional LM parameter file.
            token_type: Token type ("char"/"bpe"); falls back to the value
                stored in the training args when None.
            bpemodel: SentencePiece model path; falls back to training args.
            device: "cpu" or "cuda".
            maxlenratio / minlenratio: Output-length ratio bounds passed to
                beam search (0.0 enables automatic end detection).
            dtype: torch dtype name used for the model and beam search.
            beam_size: Beam width.
            ctc_weight: CTC weight in joint decoding (1.0 => CTC-only,
                which disables the pre-beam score key below).
            lm_weight / ngram_weight / penalty: Remaining scorer weights.
            nbest: Number of hypotheses returned per utterance.
        """
        assert check_argument_types()

        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, device
        )
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model (only when a config is supplied)
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            # Pure-CTC decoding has no pre-beam pruning key.
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # Move beam search and every nn.Module scorer to the target device/dtype.
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        # Explicit arguments win; otherwise reuse what the model was trained with.
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel

        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                # BPE requested but no model available -> leave text undecoded.
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ):
        """Decode one utterance.

        Args:
            speech: Input speech data, shape (Nsamples,) — presumably an
                80-dim fbank feature sequence given the //80 below; TODO
                confirm against the caller's preprocess_fn.

        Returns:
            List of up to ``nbest`` tuples:
            (text, token, token_int, hyp, input_length, lfr_factor).
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        # NOTE(review): lfr_factor looks like a low-frame-rate stacking factor
        # derived from the last dim / 80 — confirm the feature layout.
        lfr_factor = max(1, (speech.size()[-1]//80)-1)
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, enc_len = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        assert len(enc) == 1, len(enc)

        # c. Predictor estimates token count, then the decoder produces
        # frame-synchronous acoustic scores for beam search.
        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
        # NOTE(review): the predictor's length output is overwritten with the
        # embedding's time dimension here — verify this is intentional.
        pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device)
        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        nbest_hyps = self.beam_search(
            x=enc[0], am_scores=decoder_out[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)

            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))

            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)

            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp, speech.size(1), lfr_factor))

        # assert check_return_type(results)
        return results
  194. def inference(
  195. output_dir: str,
  196. maxlenratio: float,
  197. minlenratio: float,
  198. batch_size: int,
  199. dtype: str,
  200. beam_size: int,
  201. ngpu: int,
  202. seed: int,
  203. ctc_weight: float,
  204. lm_weight: float,
  205. ngram_weight: float,
  206. penalty: float,
  207. nbest: int,
  208. num_workers: int,
  209. log_level: Union[int, str],
  210. data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
  211. key_file: Optional[str],
  212. asr_train_config: Optional[str],
  213. asr_model_file: Optional[str],
  214. lm_train_config: Optional[str],
  215. lm_file: Optional[str],
  216. word_lm_train_config: Optional[str],
  217. token_type: Optional[str],
  218. bpemodel: Optional[str],
  219. allow_variable_data_keys: bool,
  220. **kwargs,
  221. ):
  222. assert check_argument_types()
  223. if batch_size > 1:
  224. raise NotImplementedError("batch decoding is not implemented")
  225. if word_lm_train_config is not None:
  226. raise NotImplementedError("Word LM is not implemented")
  227. if ngpu > 1:
  228. raise NotImplementedError("only single GPU decoding is supported")
  229. logging.basicConfig(
  230. level=log_level,
  231. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  232. )
  233. if ngpu >= 1:
  234. device = "cuda"
  235. else:
  236. device = "cpu"
  237. # 1. Set random-seed
  238. set_all_random_seed(seed)
  239. # 2. Build speech2text
  240. speech2text_kwargs = dict(
  241. asr_train_config=asr_train_config,
  242. asr_model_file=asr_model_file,
  243. lm_train_config=lm_train_config,
  244. lm_file=lm_file,
  245. token_type=token_type,
  246. bpemodel=bpemodel,
  247. device=device,
  248. maxlenratio=maxlenratio,
  249. minlenratio=minlenratio,
  250. dtype=dtype,
  251. beam_size=beam_size,
  252. ctc_weight=ctc_weight,
  253. lm_weight=lm_weight,
  254. ngram_weight=ngram_weight,
  255. penalty=penalty,
  256. nbest=nbest,
  257. )
  258. speech2text = Speech2Text(**speech2text_kwargs)
  259. # 3. Build data-iterator
  260. loader = ASRTask.build_streaming_iterator(
  261. data_path_and_name_and_type,
  262. dtype=dtype,
  263. batch_size=batch_size,
  264. key_file=key_file,
  265. num_workers=num_workers,
  266. preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
  267. collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
  268. allow_variable_data_keys=allow_variable_data_keys,
  269. inference=True,
  270. )
  271. forward_time_total = 0.0
  272. length_total = 0.0
  273. # 7 .Start for-loop
  274. # FIXME(kamo): The output format should be discussed about
  275. with DatadirWriter(output_dir) as writer:
  276. for keys, batch in loader:
  277. assert isinstance(batch, dict), type(batch)
  278. assert all(isinstance(s, str) for s in keys), keys
  279. _bs = len(next(iter(batch.values())))
  280. assert len(keys) == _bs, f"{len(keys)} != {_bs}"
  281. batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
  282. logging.info("decoding, utt_id: {}".format(keys))
  283. # N-best list of (text, token, token_int, hyp_object)
  284. try:
  285. time_beg = time.time()
  286. results = speech2text(**batch)
  287. time_end = time.time()
  288. forward_time = time_end - time_beg
  289. lfr_factor = results[0][-1]
  290. length = results[0][-2]
  291. results = [results[0][:-2]]
  292. forward_time_total += forward_time
  293. length_total += length
  294. logging.info(
  295. "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
  296. format(length, forward_time, 100 * forward_time / (length*lfr_factor)))
  297. except TooShortUttError as e:
  298. logging.warning(f"Utterance {keys} {e}")
  299. hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
  300. results = [[" ", ["<space>"], [2], hyp]] * nbest
  301. # Only supporting batch_size==1
  302. key = keys[0]
  303. for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
  304. # Create a directory: outdir/{n}best_recog
  305. ibest_writer = writer[f"{n}best_recog"]
  306. # Write the result to each file
  307. ibest_writer["token"][key] = " ".join(token)
  308. ibest_writer["token_int"][key] = " ".join(map(str, token_int))
  309. ibest_writer["score"][key] = str(hyp.score)
  310. if text is not None:
  311. ibest_writer["text"][key] = text
  312. logging.info("decoding, predictions: {}".format(text))
  313. logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
  314. format(length_total, forward_time_total, 100 * forward_time_total / (length_total*lfr_factor)))
def get_parser():
    """Build the command-line argument parser for Paraformer ASR decoding.

    Returns:
        config_argparse.ArgumentParser: the configured parser. It is a
        project subclass — presumably it also accepts a ``--config`` YAML
        (main() pops a "config" key from the parsed args); confirm in
        funasr.utils.config_argparse.
    """
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    # --- Input data related -------------------------------------------------
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    # --- Model configuration ------------------------------------------------
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    # --- Beam-search related ------------------------------------------------
    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0 (default), it uses a end-detect "
        "function "
        "to automatically find maximum hypothesis lengths."
        "If maxlenratio<0.0, its absolute value is interpreted"
        "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument(
        "--frontend_conf",
        default=None,
        help="",
    )

    # --- Text converter related ---------------------------------------------
    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
        "If not given, refers from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
        "If not given, refers from the training args",
    )
    return parser
  458. def main(cmd=None):
  459. print(get_commandline_args(), file=sys.stderr)
  460. parser = get_parser()
  461. args = parser.parse_args(cmd)
  462. kwargs = vars(args)
  463. kwargs.pop("config", None)
  464. inference(**kwargs)
  465. if __name__ == "__main__":
  466. main()