  1. #!/usr/bin/env python3
  2. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. import argparse
  5. import logging
  6. import sys
  7. from pathlib import Path
  8. from typing import Any
  9. from typing import List
  10. from typing import Optional
  11. from typing import Sequence
  12. from typing import Tuple
  13. from typing import Union
  14. import numpy as np
  15. import torch
  16. from typeguard import check_argument_types
  17. from typeguard import check_return_type
  18. from funasr.fileio.datadir_writer import DatadirWriter
  19. from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
  20. from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
  21. from funasr.modules.beam_search.beam_search import BeamSearch
  22. from funasr.modules.beam_search.beam_search import Hypothesis
  23. from funasr.modules.scorers.ctc import CTCPrefixScorer
  24. from funasr.modules.scorers.length_bonus import LengthBonus
  25. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  26. from funasr.modules.subsampling import TooShortUttError
  27. from funasr.tasks.asr import ASRTask
  28. from funasr.tasks.lm import LMTask
  29. from funasr.text.build_tokenizer import build_tokenizer
  30. from funasr.text.token_id_converter import TokenIDConverter
  31. from funasr.torch_utils.device_funcs import to_device
  32. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  33. from funasr.utils import config_argparse
  34. from funasr.utils.cli_utils import get_commandline_args
  35. from funasr.utils.types import str2bool
  36. from funasr.utils.types import str2triple_str
  37. from funasr.utils.types import str_or_none
class Speech2Text:
    """Speech2Text class.

    Wraps an ASR model (plus an optional language model) behind a single
    callable that maps a raw waveform to an n-best list of hypotheses.

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Optional[Union[Path, str]] = None,
        asr_model_file: Optional[Union[Path, str]] = None,
        lm_train_config: Optional[Union[Path, str]] = None,
        lm_file: Optional[Union[Path, str]] = None,
        token_type: Optional[str] = None,
        bpemodel: Optional[str] = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        **kwargs,
    ):
        """Load models and assemble the beam-search decoder.

        Args:
            asr_train_config: ASR training config (yaml) used to rebuild the model.
            asr_model_file: ASR model parameter file.
            lm_train_config: Optional LM training config; when given, the LM is
                added to the beam-search scorers.
            lm_file: LM parameter file.
            token_type: Token type ("char"/"bpe"); falls back to the training args.
            bpemodel: Sentencepiece model path; falls back to the training args.
            device: "cpu" or "cuda".
            maxlenratio / minlenratio: Output-length ratios forwarded to beam search.
            batch_size: Only 1 enables the batched beam-search fast path below.
            dtype: torch dtype name used for model and beam search.
            beam_size: Beam width.
            ctc_weight / lm_weight / ngram_weight / penalty: Scorer weights.
            nbest: Number of hypotheses returned by __call__.
            streaming: Selects BatchBeamSearchOnlineSim when batchable.
        """
        assert check_argument_types()
        # 1. Build ASR model
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, device
        )
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            scorers["lm"] = lm.lm
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None
        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            # With pure-CTC decoding there is no "full" decoder score to pre-beam on.
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )
        # TODO(karita): make all scorers batchfied
        if batch_size == 1:
            # Only swap in the batched implementation when every full scorer
            # supports the batch interface; otherwise keep the generic search.
            non_batch = [
                k
                for k, v in beam_search.full_scorers.items()
                if not isinstance(v, BatchScorerInterface)
            ]
            if len(non_batch) == 0:
                if streaming:
                    # NOTE: the class is swapped in place; the instance keeps
                    # its already-configured scorers and weights.
                    beam_search.__class__ = BatchBeamSearchOnlineSim
                    beam_search.set_streaming_config(asr_train_config)
                    logging.info(
                        "BatchBeamSearchOnlineSim implementation is selected."
                    )
                else:
                    beam_search.__class__ = BatchBeamSearch
                    logging.info("BatchBeamSearch implementation is selected.")
            else:
                logging.warning(
                    f"As non-batch scorers {non_batch} are found, "
                    f"fall back to non-batch implementation."
                )
        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
        for scorer in scorers.values():
            if isinstance(scorer, torch.nn.Module):
                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
        logging.info(f"Beam_search: {beam_search}")
        logging.info(f"Decoding device={device}, dtype={dtype}")
        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")
        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray]
    ) -> List[
        Tuple[
            Optional[str],
            List[str],
            List[int],
            Hypothesis,
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data, a 1-D waveform (Nsamples,).
                # assumes a single mono utterance — TODO confirm against callers
        Returns:
            List of up to ``self.nbest`` tuples (text, token, token_int, hyp);
            ``text`` is None when no tokenizer is configured.
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        # data: (Nsamples,) -> (1, Nsamples)
        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        batch = {"speech": speech, "speech_lengths": lengths}
        # a. To device
        batch = to_device(batch, device=self.device)
        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        # Decoding assumes batch size 1.
        assert len(enc) == 1, len(enc)
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
def inference(
    output_dir: str,
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    dtype: str,
    beam_size: int,
    ngpu: int,
    seed: int,
    ctc_weight: float,
    lm_weight: float,
    ngram_weight: float,
    penalty: float,
    nbest: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    lm_train_config: Optional[str],
    lm_file: Optional[str],
    word_lm_train_config: Optional[str],
    token_type: Optional[str],
    bpemodel: Optional[str],
    allow_variable_data_keys: bool,
    streaming: bool,
    **kwargs,
) -> None:
    """Decode a dataset with beam search and write n-best results to disk.

    Builds a :class:`Speech2Text` instance from the given configs, iterates
    the dataset described by ``data_path_and_name_and_type``, and writes
    per-utterance token / token_int / score / text files under
    ``output_dir/{n}best_recog``.

    Raises:
        NotImplementedError: for ``batch_size > 1``, word LMs, or multi-GPU
            decoding, none of which are supported here.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)
    # 3. Build data-iterator
    loader = ASRTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )
    # 4. Decode each utterance and write results
    # FIXME(kamo): The output format should be discussed about
    with DatadirWriter(output_dir) as writer:
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # Drop the *_lengths tensors and unwrap the batch dimension
            # (batch_size is guaranteed to be 1 above).
            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Utterance too short for the encoder's subsampling:
                # emit placeholder hypotheses instead of aborting the run.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["<space>"], [2], hyp]] * nbest
            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                ibest_writer = writer[f"{n}best_recog"]
                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["score"][key] = str(hyp.score)
                if text is not None:
                    ibest_writer["text"][key] = text
  336. def get_parser():
  337. parser = config_argparse.ArgumentParser(
  338. description="ASR Decoding",
  339. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  340. )
  341. # Note(kamo): Use '_' instead of '-' as separator.
  342. # '-' is confusing if written in yaml.
  343. parser.add_argument(
  344. "--log_level",
  345. type=lambda x: x.upper(),
  346. default="INFO",
  347. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  348. help="The verbose level of logging",
  349. )
  350. parser.add_argument("--output_dir", type=str, required=True)
  351. parser.add_argument(
  352. "--ngpu",
  353. type=int,
  354. default=0,
  355. help="The number of gpus. 0 indicates CPU mode",
  356. )
  357. parser.add_argument(
  358. "--gpuid_list",
  359. type=str,
  360. default="",
  361. help="The visible gpus",
  362. )
  363. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  364. parser.add_argument(
  365. "--dtype",
  366. default="float32",
  367. choices=["float16", "float32", "float64"],
  368. help="Data type",
  369. )
  370. parser.add_argument(
  371. "--num_workers",
  372. type=int,
  373. default=1,
  374. help="The number of workers used for DataLoader",
  375. )
  376. group = parser.add_argument_group("Input data related")
  377. group.add_argument(
  378. "--data_path_and_name_and_type",
  379. type=str2triple_str,
  380. required=True,
  381. action="append",
  382. )
  383. group.add_argument("--key_file", type=str_or_none)
  384. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  385. group = parser.add_argument_group("The model configuration related")
  386. group.add_argument(
  387. "--asr_train_config",
  388. type=str,
  389. help="ASR training configuration",
  390. )
  391. group.add_argument(
  392. "--asr_model_file",
  393. type=str,
  394. help="ASR model parameter file",
  395. )
  396. group.add_argument(
  397. "--lm_train_config",
  398. type=str,
  399. help="LM training configuration",
  400. )
  401. group.add_argument(
  402. "--lm_file",
  403. type=str,
  404. help="LM parameter file",
  405. )
  406. group.add_argument(
  407. "--word_lm_train_config",
  408. type=str,
  409. help="Word LM training configuration",
  410. )
  411. group.add_argument(
  412. "--word_lm_file",
  413. type=str,
  414. help="Word LM parameter file",
  415. )
  416. group.add_argument(
  417. "--ngram_file",
  418. type=str,
  419. help="N-gram parameter file",
  420. )
  421. group.add_argument(
  422. "--model_tag",
  423. type=str,
  424. help="Pretrained model tag. If specify this option, *_train_config and "
  425. "*_file will be overwritten",
  426. )
  427. group = parser.add_argument_group("Beam-search related")
  428. group.add_argument(
  429. "--batch_size",
  430. type=int,
  431. default=1,
  432. help="The batch size for inference",
  433. )
  434. group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
  435. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  436. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  437. group.add_argument(
  438. "--maxlenratio",
  439. type=float,
  440. default=0.0,
  441. help="Input length ratio to obtain max output length. "
  442. "If maxlenratio=0.0 (default), it uses a end-detect "
  443. "function "
  444. "to automatically find maximum hypothesis lengths."
  445. "If maxlenratio<0.0, its absolute value is interpreted"
  446. "as a constant max output length",
  447. )
  448. group.add_argument(
  449. "--minlenratio",
  450. type=float,
  451. default=0.0,
  452. help="Input length ratio to obtain min output length",
  453. )
  454. group.add_argument(
  455. "--ctc_weight",
  456. type=float,
  457. default=0.5,
  458. help="CTC weight in joint decoding",
  459. )
  460. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  461. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  462. group.add_argument("--streaming", type=str2bool, default=False)
  463. group = parser.add_argument_group("Text converter related")
  464. group.add_argument(
  465. "--token_type",
  466. type=str_or_none,
  467. default=None,
  468. choices=["char", "bpe", None],
  469. help="The token type for ASR model. "
  470. "If not given, refers from the training args",
  471. )
  472. group.add_argument(
  473. "--bpemodel",
  474. type=str_or_none,
  475. default=None,
  476. help="The model path of sentencepiece. "
  477. "If not given, refers from the training args",
  478. )
  479. return parser
  480. def main(cmd=None):
  481. print(get_commandline_args(), file=sys.stderr)
  482. parser = get_parser()
  483. args = parser.parse_args(cmd)
  484. kwargs = vars(args)
  485. kwargs.pop("config", None)
  486. inference(**kwargs)
  487. if __name__ == "__main__":
  488. main()