# asr_inference.py
  1. #!/usr/bin/env python3
  2. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. import argparse
  5. import logging
  6. import sys
  7. from pathlib import Path
  8. from typing import Any
  9. from typing import List
  10. from typing import Optional
  11. from typing import Sequence
  12. from typing import Tuple
  13. from typing import Union
  14. from typing import Dict
  15. import numpy as np
  16. import torch
  17. from typeguard import check_argument_types
  18. from typeguard import check_return_type
  19. from funasr.fileio.datadir_writer import DatadirWriter
  20. from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
  21. from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
  22. from funasr.modules.beam_search.beam_search import BeamSearch
  23. from funasr.modules.beam_search.beam_search import Hypothesis
  24. from funasr.modules.scorers.ctc import CTCPrefixScorer
  25. from funasr.modules.scorers.length_bonus import LengthBonus
  26. from funasr.modules.scorers.scorer_interface import BatchScorerInterface
  27. from funasr.modules.subsampling import TooShortUttError
  28. from funasr.tasks.asr import ASRTask
  29. from funasr.tasks.lm import LMTask
  30. from funasr.text.build_tokenizer import build_tokenizer
  31. from funasr.text.token_id_converter import TokenIDConverter
  32. from funasr.torch_utils.device_funcs import to_device
  33. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  34. from funasr.utils import config_argparse
  35. from funasr.utils.cli_utils import get_commandline_args
  36. from funasr.utils.types import str2bool
  37. from funasr.utils.types import str2triple_str
  38. from funasr.utils.types import str_or_none
  39. from funasr.utils import asr_utils, wav_utils, postprocess_utils
  40. from funasr.models.frontend.wav_frontend import WavFrontend
# ANSI terminal escape codes: bright-magenta "header" color and the reset
# sequence used to close a colored span.
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
    """End-to-end ASR inference wrapper.

    Loads a trained ASR model (optionally with an external LM), builds a
    beam-search decoder over the combined scorers, and decodes raw speech
    into n-best ``(text, token, token_int, hypothesis)`` tuples.

    Examples:
        >>> import soundfile
        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2text(audio)
        [(text, token, token_int, hypothesis object), ...]
    """

    def __init__(
        self,
        asr_train_config: Optional[Union[Path, str]] = None,
        asr_model_file: Optional[Union[Path, str]] = None,
        cmvn_file: Optional[Union[Path, str]] = None,
        lm_train_config: Optional[Union[Path, str]] = None,
        lm_file: Optional[Union[Path, str]] = None,
        token_type: Optional[str] = None,
        bpemodel: Optional[str] = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        batch_size: int = 1,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        streaming: bool = False,
        frontend_conf: dict = None,
        **kwargs,
    ):
        assert check_argument_types()

        # 1. Build ASR model from its training config + checkpoint.
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        # A standalone WavFrontend is built only when the training args carry
        # a frontend config; __call__ then runs it manually and disables the
        # model-internal frontend.
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()

        decoder = asr_model.decoder
        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
        token_list = asr_model.token_list
        scorers.update(
            decoder=decoder,
            ctc=ctc,
            length_bonus=LengthBonus(len(token_list)),
        )

        # 2. Build Language model (only when a config is supplied).
        if lm_train_config is not None:
            lm, lm_train_args = LMTask.build_model_from_file(
                lm_train_config, lm_file, device
            )
            scorers["lm"] = lm.lm

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        # 4. Build BeamSearch object
        # transducer is not supported now
        beam_search_transducer = None

        weights = dict(
            decoder=1.0 - ctc_weight,
            ctc=ctc_weight,
            lm=lm_weight,
            ngram=ngram_weight,
            length_bonus=penalty,
        )
        beam_search = BeamSearch(
            beam_size=beam_size,
            weights=weights,
            scorers=scorers,
            sos=asr_model.sos,
            eos=asr_model.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            # Pure-CTC decoding has no "full" attention score to pre-prune on,
            # so the pre-beam is disabled in that case.
            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
        )

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text.
        # Fall back to the settings recorded at training time.
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.beam_search = beam_search
        self.beam_search_transducer = beam_search_transducer
        self.maxlenratio = maxlenratio
        self.minlenratio = minlenratio
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Optional[Union[torch.Tensor, np.ndarray]] = None
    ) -> List[
        Tuple[
            Optional[str],
            List[str],
            List[int],
            Union[Hypothesis],
        ]
    ]:
        """Inference

        Args:
            speech: Input speech data
        Returns:
            text, token, token_int, hyp
        """
        assert check_argument_types()

        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)

        if self.frontend is not None:
            # Extract features externally, then bypass the model's own
            # frontend so encode() consumes the features as-is.
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        # NOTE(review): lfr_factor is computed but never used below —
        # presumably a leftover from low-frame-rate handling; confirm intent.
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}

        # a. To device
        batch = to_device(batch, device=self.device)

        # b. Forward Encoder
        enc, _ = self.asr_model.encode(**batch)
        if isinstance(enc, tuple):
            enc = enc[0]
        # Only batch size 1 is supported by this decoding path.
        assert len(enc) == 1, len(enc)

        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(
            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
        )
        nbest_hyps = nbest_hyps[: self.nbest]

        results = []
        for hyp in nbest_hyps:
            assert isinstance(hyp, (Hypothesis)), type(hyp)
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = self.converter.ids2tokens(token_int)
            if self.tokenizer is not None:
                text = self.tokenizer.tokens2text(token)
            else:
                text = None
            results.append((text, token, token_int, hyp))
        assert check_return_type(results)
        return results
  217. def inference(
  218. maxlenratio: float,
  219. minlenratio: float,
  220. batch_size: int,
  221. beam_size: int,
  222. ngpu: int,
  223. ctc_weight: float,
  224. lm_weight: float,
  225. penalty: float,
  226. log_level: Union[int, str],
  227. data_path_and_name_and_type,
  228. asr_train_config: Optional[str],
  229. asr_model_file: Optional[str],
  230. cmvn_file: Optional[str] = None,
  231. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  232. lm_train_config: Optional[str] = None,
  233. lm_file: Optional[str] = None,
  234. token_type: Optional[str] = None,
  235. key_file: Optional[str] = None,
  236. word_lm_train_config: Optional[str] = None,
  237. bpemodel: Optional[str] = None,
  238. allow_variable_data_keys: bool = False,
  239. streaming: bool = False,
  240. output_dir: Optional[str] = None,
  241. dtype: str = "float32",
  242. seed: int = 0,
  243. ngram_weight: float = 0.9,
  244. nbest: int = 1,
  245. num_workers: int = 1,
  246. **kwargs,
  247. ):
  248. inference_pipeline = inference_modelscope(
  249. maxlenratio=maxlenratio,
  250. minlenratio=minlenratio,
  251. batch_size=batch_size,
  252. beam_size=beam_size,
  253. ngpu=ngpu,
  254. ctc_weight=ctc_weight,
  255. lm_weight=lm_weight,
  256. penalty=penalty,
  257. log_level=log_level,
  258. asr_train_config=asr_train_config,
  259. asr_model_file=asr_model_file,
  260. cmvn_file=cmvn_file,
  261. raw_inputs=raw_inputs,
  262. lm_train_config=lm_train_config,
  263. lm_file=lm_file,
  264. token_type=token_type,
  265. key_file=key_file,
  266. word_lm_train_config=word_lm_train_config,
  267. bpemodel=bpemodel,
  268. allow_variable_data_keys=allow_variable_data_keys,
  269. streaming=streaming,
  270. output_dir=output_dir,
  271. dtype=dtype,
  272. seed=seed,
  273. ngram_weight=ngram_weight,
  274. nbest=nbest,
  275. num_workers=num_workers,
  276. **kwargs,
  277. )
  278. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    streaming: bool = False,
    output_dir: Optional[str] = None,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    param_dict: Optional[dict] = None,  # NOTE(review): accepted but never read here — TODO confirm
    **kwargs,
):
    """Build a reusable ASR decoding pipeline.

    Loads the model once (the expensive part) and returns a ``_forward``
    closure that can be called repeatedly with either a data list or an
    in-memory waveform, producing ``[{'key': ..., 'value': ...}, ...]``
    recognition results and optionally writing Kaldi-style result files.

    Raises:
        NotImplementedError: for batch decoding, word LMs, or multi-GPU runs.
    """
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    # Any ngpu >= 1 means "use the (single) CUDA device" when one exists.
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text (captured by the closure below and reused
    # across calls).
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        streaming=streaming,
    )
    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
    speech2text = Speech2Text(**speech2text_kwargs)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 ):
        # Decode one request with the already-built `speech2text`.
        # NOTE(review): this inner `param_dict` shadows the outer one and is
        # also unused — TODO confirm intent.
        # 3. Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            # Wrap an in-memory waveform so it looks like a dataset entry.
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = ASRTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )

        finish_count = 0
        # NOTE(review): file_count is never incremented, so print_progress
        # receives finish_count / 1 — verify whether a total was intended.
        file_count = 1
        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        # A per-call output dir (output_dir_v2) overrides the one captured
        # when the pipeline was built.
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path is not None:
            writer = DatadirWriter(output_path)
        else:
            writer = None
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}

            # N-best list of (text, token, token_int, hyp_object)
            try:
                results = speech2text(**batch)
            except TooShortUttError as e:
                # Utterances too short for the subsampling layers fall back
                # to a dummy "silence" hypothesis instead of failing the run.
                logging.warning(f"Utterance {keys} {e}")
                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
                results = [[" ", ["sil"], [2], hyp]] * nbest

            # Only supporting batch_size==1
            key = keys[0]
            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
                # Create a directory: outdir/{n}best_recog
                if writer is not None:
                    ibest_writer = writer[f"{n}best_recog"]
                    # Write the result to each file
                    ibest_writer["token"][key] = " ".join(token)
                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                    ibest_writer["score"][key] = str(hyp.score)

                if text is not None:
                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
                    item = {'key': key, 'value': text_postprocessed}
                    asr_result_list.append(item)
                    finish_count += 1
                    asr_utils.print_progress(finish_count / file_count)
                    if writer is not None:
                        ibest_writer["text"][key] = text
        return asr_result_list
    return _forward
  416. def set_parameters(language: str = None,
  417. sample_rate: Union[int, Dict[Any, int]] = None):
  418. if language is not None:
  419. global global_asr_language
  420. global_asr_language = language
  421. if sample_rate is not None:
  422. global global_sample_rate
  423. global_sample_rate = sample_rate
def get_parser():
    """Construct the command-line parser for ASR decoding.

    Returns:
        config_argparse.ArgumentParser: parser whose namespace maps 1:1 onto
        the keyword arguments of ``inference()`` (after ``config`` is dropped).
    """
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    # Input data related
    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--raw_inputs", type=list, default=None)
    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    # Model files / training configs
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global cmvn file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
        "*_file will be overwritten",
    )

    # Beam-search hyper-parameters
    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0 (default), it uses a end-detect "
        "function "
        "to automatically find maximum hypothesis lengths."
        "If maxlenratio<0.0, its absolute value is interpreted"
        "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)

    # Token / text conversion
    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
        "If not given, refers from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
        "If not given, refers from the training args",
    )
    return parser
  575. def main(cmd=None):
  576. print(get_commandline_args(), file=sys.stderr)
  577. parser = get_parser()
  578. args = parser.parse_args(cmd)
  579. kwargs = vars(args)
  580. kwargs.pop("config", None)
  581. inference(**kwargs)
  582. if __name__ == "__main__":
  583. main()