# asr_inference_paraformer_streaming.py
#!/usr/bin/env python3
"""Streaming Paraformer ASR inference: chunk-by-chunk decoding with a caller-held cache."""
# --- standard library ---
import argparse
import logging
import sys
import time
import copy
import os
import codecs
import tempfile
import requests
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
from typing import Any
from typing import List

# --- third party ---
import numpy as np
import torch
import torchaudio
from typeguard import check_argument_types

# --- funasr project ---
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export

# Print arrays in full (no "..." truncation) when they are logged/printed.
np.set_printoptions(threshold=np.inf)
  45. class Speech2Text:
  46. """Speech2Text class
  47. Examples:
  48. >>> import soundfile
  49. >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
  50. >>> audio, rate = soundfile.read("speech.wav")
  51. >>> speech2text(audio)
  52. [(text, token, token_int, hypothesis object), ...]
  53. """
  54. def __init__(
  55. self,
  56. asr_train_config: Union[Path, str] = None,
  57. asr_model_file: Union[Path, str] = None,
  58. cmvn_file: Union[Path, str] = None,
  59. lm_train_config: Union[Path, str] = None,
  60. lm_file: Union[Path, str] = None,
  61. token_type: str = None,
  62. bpemodel: str = None,
  63. device: str = "cpu",
  64. maxlenratio: float = 0.0,
  65. minlenratio: float = 0.0,
  66. dtype: str = "float32",
  67. beam_size: int = 20,
  68. ctc_weight: float = 0.5,
  69. lm_weight: float = 1.0,
  70. ngram_weight: float = 0.9,
  71. penalty: float = 0.0,
  72. nbest: int = 1,
  73. frontend_conf: dict = None,
  74. hotword_list_or_file: str = None,
  75. **kwargs,
  76. ):
  77. assert check_argument_types()
  78. # 1. Build ASR model
  79. scorers = {}
  80. asr_model, asr_train_args = ASRTask.build_model_from_file(
  81. asr_train_config, asr_model_file, cmvn_file, device
  82. )
  83. frontend = None
  84. if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
  85. frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
  86. logging.info("asr_model: {}".format(asr_model))
  87. logging.info("asr_train_args: {}".format(asr_train_args))
  88. asr_model.to(dtype=getattr(torch, dtype)).eval()
  89. if asr_model.ctc != None:
  90. ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
  91. scorers.update(
  92. ctc=ctc
  93. )
  94. token_list = asr_model.token_list
  95. scorers.update(
  96. length_bonus=LengthBonus(len(token_list)),
  97. )
  98. # 2. Build Language model
  99. if lm_train_config is not None:
  100. lm, lm_train_args = LMTask.build_model_from_file(
  101. lm_train_config, lm_file, device
  102. )
  103. scorers["lm"] = lm.lm
  104. # 3. Build ngram model
  105. # ngram is not supported now
  106. ngram = None
  107. scorers["ngram"] = ngram
  108. # 4. Build BeamSearch object
  109. # transducer is not supported now
  110. beam_search_transducer = None
  111. weights = dict(
  112. decoder=1.0 - ctc_weight,
  113. ctc=ctc_weight,
  114. lm=lm_weight,
  115. ngram=ngram_weight,
  116. length_bonus=penalty,
  117. )
  118. beam_search = BeamSearch(
  119. beam_size=beam_size,
  120. weights=weights,
  121. scorers=scorers,
  122. sos=asr_model.sos,
  123. eos=asr_model.eos,
  124. vocab_size=len(token_list),
  125. token_list=token_list,
  126. pre_beam_score_key=None if ctc_weight == 1.0 else "full",
  127. )
  128. beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
  129. for scorer in scorers.values():
  130. if isinstance(scorer, torch.nn.Module):
  131. scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
  132. logging.info(f"Decoding device={device}, dtype={dtype}")
  133. # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
  134. if token_type is None:
  135. token_type = asr_train_args.token_type
  136. if bpemodel is None:
  137. bpemodel = asr_train_args.bpemodel
  138. if token_type is None:
  139. tokenizer = None
  140. elif token_type == "bpe":
  141. if bpemodel is not None:
  142. tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
  143. else:
  144. tokenizer = None
  145. else:
  146. tokenizer = build_tokenizer(token_type=token_type)
  147. converter = TokenIDConverter(token_list=token_list)
  148. logging.info(f"Text tokenizer: {tokenizer}")
  149. self.asr_model = asr_model
  150. self.asr_train_args = asr_train_args
  151. self.converter = converter
  152. self.tokenizer = tokenizer
  153. # 6. [Optional] Build hotword list from str, local file or url
  154. is_use_lm = lm_weight != 0.0 and lm_file is not None
  155. if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
  156. beam_search = None
  157. self.beam_search = beam_search
  158. logging.info(f"Beam_search: {self.beam_search}")
  159. self.beam_search_transducer = beam_search_transducer
  160. self.maxlenratio = maxlenratio
  161. self.minlenratio = minlenratio
  162. self.device = device
  163. self.dtype = dtype
  164. self.nbest = nbest
  165. self.frontend = frontend
  166. self.encoder_downsampling_factor = 1
  167. if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
  168. self.encoder_downsampling_factor = 4
  169. @torch.no_grad()
  170. def __call__(
  171. self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
  172. begin_time: int = 0, end_time: int = None,
  173. ):
  174. """Inference
  175. Args:
  176. speech: Input speech data
  177. Returns:
  178. text, token, token_int, hyp
  179. """
  180. assert check_argument_types()
  181. # Input as audio signal
  182. if isinstance(speech, np.ndarray):
  183. speech = torch.tensor(speech)
  184. if self.frontend is not None:
  185. feats, feats_len = self.frontend.forward(speech, speech_lengths)
  186. feats = to_device(feats, device=self.device)
  187. feats_len = feats_len.int()
  188. self.asr_model.frontend = None
  189. else:
  190. feats = speech
  191. feats_len = speech_lengths
  192. lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
  193. feats_len = cache["encoder"]["stride"] + cache["encoder"]["pad_left"] + cache["encoder"]["pad_right"]
  194. feats = feats[:,cache["encoder"]["start_idx"]:cache["encoder"]["start_idx"]+feats_len,:]
  195. feats_len = torch.tensor([feats_len])
  196. batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache}
  197. # a. To device
  198. batch = to_device(batch, device=self.device)
  199. # b. Forward Encoder
  200. enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache)
  201. if isinstance(enc, tuple):
  202. enc = enc[0]
  203. # assert len(enc) == 1, len(enc)
  204. enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
  205. predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache)
  206. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
  207. predictor_outs[2], predictor_outs[3]
  208. pre_token_length = pre_token_length.floor().long()
  209. if torch.max(pre_token_length) < 1:
  210. return []
  211. decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache)
  212. decoder_out = decoder_outs
  213. results = []
  214. b, n, d = decoder_out.size()
  215. for i in range(b):
  216. x = enc[i, :enc_len[i], :]
  217. am_scores = decoder_out[i, :pre_token_length[i], :]
  218. if self.beam_search is not None:
  219. nbest_hyps = self.beam_search(
  220. x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
  221. )
  222. nbest_hyps = nbest_hyps[: self.nbest]
  223. else:
  224. yseq = am_scores.argmax(dim=-1)
  225. score = am_scores.max(dim=-1)[0]
  226. score = torch.sum(score, dim=-1)
  227. # pad with mask tokens to ensure compatibility with sos/eos tokens
  228. yseq = torch.tensor(
  229. [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
  230. )
  231. nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
  232. for hyp in nbest_hyps:
  233. assert isinstance(hyp, (Hypothesis)), type(hyp)
  234. # remove sos/eos and get results
  235. last_pos = -1
  236. if isinstance(hyp.yseq, list):
  237. token_int = hyp.yseq[1:last_pos]
  238. else:
  239. token_int = hyp.yseq[1:last_pos].tolist()
  240. # remove blank symbol id, which is assumed to be 0
  241. token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
  242. # Change integer-ids to tokens
  243. token = self.converter.ids2tokens(token_int)
  244. if self.tokenizer is not None:
  245. text = self.tokenizer.tokens2text(token)
  246. else:
  247. text = None
  248. results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
  249. # assert check_return_type(results)
  250. return results
class Speech2TextExport:
    """Speech2TextExport class.

    Variant of ``Speech2Text`` that wraps the loaded model in the
    export-oriented ``Paraformer_export`` module and decodes whole utterances
    in one pass (no streaming cache, argmax decoding only).
    """

    def __init__(
        self,
        asr_train_config: Union[Path, str] = None,
        asr_model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        lm_train_config: Union[Path, str] = None,
        lm_file: Union[Path, str] = None,
        token_type: str = None,
        bpemodel: str = None,
        device: str = "cpu",
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        dtype: str = "float32",
        beam_size: int = 20,
        ctc_weight: float = 0.5,
        lm_weight: float = 1.0,
        ngram_weight: float = 0.9,
        penalty: float = 0.0,
        nbest: int = 1,
        frontend_conf: dict = None,
        hotword_list_or_file: str = None,
        **kwargs,
    ):
        # 1. Build ASR model
        asr_model, asr_train_args = ASRTask.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
        asr_model.to(dtype=getattr(torch, dtype)).eval()
        token_list = asr_model.token_list
        logging.info(f"Decoding device={device}, dtype={dtype}")

        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
        if token_type is None:
            token_type = asr_train_args.token_type
        if bpemodel is None:
            bpemodel = asr_train_args.bpemodel
        if token_type is None:
            tokenizer = None
        elif token_type == "bpe":
            if bpemodel is not None:
                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
            else:
                tokenizer = None
        else:
            tokenizer = build_tokenizer(token_type=token_type)
        converter = TokenIDConverter(token_list=token_list)
        logging.info(f"Text tokenizer: {tokenizer}")

        # self.asr_model = asr_model
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype
        self.nbest = nbest
        self.frontend = frontend
        # Wrap the raw torch model in the export module; presumably
        # ``onnx=False`` selects the non-ONNX export graph -- TODO confirm.
        model = Paraformer_export(asr_model, onnx=False)
        self.asr_model = model

    @torch.no_grad()
    def __call__(
        self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
    ):
        """Inference over a whole utterance.

        Args:
            speech: Input speech data (waveform if a frontend is set,
                otherwise precomputed features).
            speech_lengths: Lengths matching ``speech``.
        Returns:
            List of tuples
            ``(text, token, token_int, hyp, enc_len_batch_total, lfr_factor)``.
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            # Features are computed here, so disable the model's own frontend.
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        enc_len_batch_total = feats_len.sum()
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
        decoder_outs = self.asr_model(**batch)
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            am_scores = decoder_out[i, :ys_pad_lens[i], :]
            # Greedy decoding over the acoustic scores.
            yseq = am_scores.argmax(dim=-1)
            score = am_scores.max(dim=-1)[0]
            score = torch.sum(score, dim=-1)
            # NOTE(review): unlike Speech2Text, no sos/eos ids are added here,
            # yet the [1:last_pos] slice below still strips the first and last
            # entries of yseq -- confirm this is intended for the export path.
            yseq = torch.tensor(
                yseq.tolist(), device=yseq.device
            )
            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
                # Change integer-ids to tokens
                token = self.converter.ids2tokens(token_int)
                if self.tokenizer is not None:
                    text = self.tokenizer.tokens2text(token)
                else:
                    text = None
                results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
        return results
  374. def inference(
  375. maxlenratio: float,
  376. minlenratio: float,
  377. batch_size: int,
  378. beam_size: int,
  379. ngpu: int,
  380. ctc_weight: float,
  381. lm_weight: float,
  382. penalty: float,
  383. log_level: Union[int, str],
  384. data_path_and_name_and_type,
  385. asr_train_config: Optional[str],
  386. asr_model_file: Optional[str],
  387. cmvn_file: Optional[str] = None,
  388. raw_inputs: Union[np.ndarray, torch.Tensor] = None,
  389. lm_train_config: Optional[str] = None,
  390. lm_file: Optional[str] = None,
  391. token_type: Optional[str] = None,
  392. key_file: Optional[str] = None,
  393. word_lm_train_config: Optional[str] = None,
  394. bpemodel: Optional[str] = None,
  395. allow_variable_data_keys: bool = False,
  396. streaming: bool = False,
  397. output_dir: Optional[str] = None,
  398. dtype: str = "float32",
  399. seed: int = 0,
  400. ngram_weight: float = 0.9,
  401. nbest: int = 1,
  402. num_workers: int = 1,
  403. **kwargs,
  404. ):
  405. inference_pipeline = inference_modelscope(
  406. maxlenratio=maxlenratio,
  407. minlenratio=minlenratio,
  408. batch_size=batch_size,
  409. beam_size=beam_size,
  410. ngpu=ngpu,
  411. ctc_weight=ctc_weight,
  412. lm_weight=lm_weight,
  413. penalty=penalty,
  414. log_level=log_level,
  415. asr_train_config=asr_train_config,
  416. asr_model_file=asr_model_file,
  417. cmvn_file=cmvn_file,
  418. raw_inputs=raw_inputs,
  419. lm_train_config=lm_train_config,
  420. lm_file=lm_file,
  421. token_type=token_type,
  422. key_file=key_file,
  423. word_lm_train_config=word_lm_train_config,
  424. bpemodel=bpemodel,
  425. allow_variable_data_keys=allow_variable_data_keys,
  426. streaming=streaming,
  427. output_dir=output_dir,
  428. dtype=dtype,
  429. seed=seed,
  430. ngram_weight=ngram_weight,
  431. nbest=nbest,
  432. num_workers=num_workers,
  433. **kwargs,
  434. )
  435. return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
    maxlenratio: float,
    minlenratio: float,
    batch_size: int,
    beam_size: int,
    ngpu: int,
    ctc_weight: float,
    lm_weight: float,
    penalty: float,
    log_level: Union[int, str],
    # data_path_and_name_and_type,
    asr_train_config: Optional[str],
    asr_model_file: Optional[str],
    cmvn_file: Optional[str] = None,
    lm_train_config: Optional[str] = None,
    lm_file: Optional[str] = None,
    token_type: Optional[str] = None,
    key_file: Optional[str] = None,
    word_lm_train_config: Optional[str] = None,
    bpemodel: Optional[str] = None,
    allow_variable_data_keys: bool = False,
    dtype: str = "float32",
    seed: int = 0,
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    output_dir: Optional[str] = None,
    param_dict: dict = None,
    **kwargs,
):
    """Build the streaming pipeline once and return a ``_forward`` closure.

    The returned closure is called once per incoming audio piece; it carries
    streaming state between calls in ``param_dict["cache"]`` (the closure's
    own ``param_dict`` parameter shadows the outer one, which is unused in
    the code below).
    """
    assert check_argument_types()
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    # Export path (Speech2TextExport) is kept but hard-disabled here.
    export_mode = False
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        batch_size = 1  # value is not read again below

    # 1. Set random-seed
    set_all_random_seed(seed)

    # 2. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
    )
    if export_mode:
        speech2text = Speech2TextExport(**speech2text_kwargs)
    else:
        speech2text = Speech2Text(**speech2text_kwargs)

    def _load_bytes(input):
        """Convert raw int16 PCM bytes to a float32 waveform scaled to [-1, 1)."""
        middle_data = np.frombuffer(input, dtype=np.int16)
        middle_data = np.asarray(middle_data)
        if middle_data.dtype.kind not in 'iu':
            raise TypeError("'middle_data' must be an array of integers")
        dtype = np.dtype('float32')
        if dtype.kind != 'f':
            raise TypeError("'dtype' must be a floating point type")
        i = np.iinfo(middle_data.dtype)
        abs_max = 2 ** (i.bits - 1)
        offset = i.min + abs_max
        # Shift the signed range to be symmetric about zero, then normalize.
        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
        return array

    def _forward(
        data_path_and_name_and_type,
        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
        output_dir_v2: Optional[str] = None,
        fs: dict = None,
        param_dict: dict = None,  # shadows the outer param_dict; holds "cache"/"is_final"
        **kwargs,
    ):
        """Decode one incoming audio piece; streaming state lives in param_dict["cache"]."""
        # 3. Build data-iterator
        is_final = False
        cache = {}
        if param_dict is not None and "cache" in param_dict:
            cache = param_dict["cache"]
        if param_dict is not None and "is_final" in param_dict:
            is_final = param_dict["is_final"]
        # Input may arrive as raw PCM bytes, a sound file path, or an array.
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            is_final = True  # a whole file is decoded in one call
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)

        # 7 .Start for-loop
        # FIXME(kamo): The output format should be discussed about
        asr_result_list = []
        results = []
        asr_result = ""
        wait = True
        if len(cache) == 0:
            # First call for this stream: initialize chunk bookkeeping.
            # start_idx/stride/pad_* are in encoder frames; the buffer below
            # is consumed in 960-sample units — presumably one frame is
            # 960 samples (60 ms at 16 kHz), TODO confirm.
            cache["encoder"] = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None, "is_final": is_final, "left": 0, "right": 0}
            cache_de = {"decode_fsmn": None}
            cache["decoder"] = cache_de
            cache["first_chunk"] = True
            cache["speech"] = []
            cache["accum_speech"] = 0
        if raw_inputs is not None:
            # Append the new samples to the stream buffer.
            if len(cache["speech"]) == 0:
                cache["speech"] = raw_inputs
            else:
                cache["speech"] = torch.cat([cache["speech"], raw_inputs], dim=0)
            cache["accum_speech"] += len(raw_inputs)
            # Consume the buffer chunk by chunk (frame granularity: 960 samples).
            while cache["accum_speech"] >= 960:
                if cache["first_chunk"]:
                    if cache["accum_speech"] >= 14400:
                        # Enough for the first chunk (15 frames: 10 stride + 5 right pad).
                        speech = torch.unsqueeze(cache["speech"], axis=0)
                        speech_length = torch.tensor([len(cache["speech"])])
                        cache["encoder"]["pad_left"] = 5
                        cache["encoder"]["pad_right"] = 5
                        cache["encoder"]["stride"] = 10
                        cache["encoder"]["left"] = 5
                        cache["encoder"]["right"] = 0
                        results = speech2text(cache, speech, speech_length)
                        cache["accum_speech"] -= 4800
                        cache["first_chunk"] = False
                        cache["encoder"]["start_idx"] = -5
                        cache["encoder"]["is_final"] = False
                        wait = False
                    else:
                        if is_final:
                            # Stream ends before a full first chunk: decode what we have.
                            cache["encoder"]["stride"] = len(cache["speech"]) // 960
                            cache["encoder"]["pad_left"] = 0
                            cache["encoder"]["pad_right"] = 0
                            speech = torch.unsqueeze(cache["speech"], axis=0)
                            speech_length = torch.tensor([len(cache["speech"])])
                            results = speech2text(cache, speech, speech_length)
                            cache["accum_speech"] = 0
                            wait = False
                        else:
                            # Not enough audio yet; wait for more input.
                            break
                else:
                    if cache["accum_speech"] >= 19200:
                        # Regular mid-stream chunk: advance the window by 10 frames.
                        cache["encoder"]["start_idx"] += 10
                        cache["encoder"]["stride"] = 10
                        cache["encoder"]["pad_left"] = 5
                        cache["encoder"]["pad_right"] = 5
                        cache["encoder"]["left"] = 0
                        cache["encoder"]["right"] = 0
                        speech = torch.unsqueeze(cache["speech"], axis=0)
                        speech_length = torch.tensor([len(cache["speech"])])
                        results = speech2text(cache, speech, speech_length)
                        cache["accum_speech"] -= 9600
                        wait = False
                    else:
                        if is_final:
                            cache["encoder"]["is_final"] = True
                            if cache["accum_speech"] >= 14400:
                                # Penultimate chunk: remainder becomes right context.
                                cache["encoder"]["start_idx"] += 10
                                cache["encoder"]["stride"] = 10
                                cache["encoder"]["pad_left"] = 5
                                cache["encoder"]["pad_right"] = 5
                                cache["encoder"]["left"] = 0
                                cache["encoder"]["right"] = cache["accum_speech"] // 960 - 15
                                speech = torch.unsqueeze(cache["speech"], axis=0)
                                speech_length = torch.tensor([len(cache["speech"])])
                                results = speech2text(cache, speech, speech_length)
                                cache["accum_speech"] -= 9600
                                wait = False
                            else:
                                # Final short chunk consumes the whole remainder.
                                cache["encoder"]["start_idx"] += 10
                                cache["encoder"]["stride"] = cache["accum_speech"] // 960 - 5
                                cache["encoder"]["pad_left"] = 5
                                cache["encoder"]["pad_right"] = 0
                                cache["encoder"]["left"] = 0
                                cache["encoder"]["right"] = 0
                                speech = torch.unsqueeze(cache["speech"], axis=0)
                                speech_length = torch.tensor([len(cache["speech"])])
                                results = speech2text(cache, speech, speech_length)
                                cache["accum_speech"] = 0
                                wait = False
                        else:
                            break
                # Accumulate the recognized text of this chunk, if any.
                if len(results) >= 1:
                    asr_result += results[0][0]
            if asr_result == "":
                asr_result = "sil"
            if wait:
                asr_result = "waiting_for_more_voice"
            item = {'key': "utt", 'value': asr_result}
            asr_result_list.append(item)
        else:
            # No audio was provided in this call.
            return []
        return asr_result_list
    return _forward
  645. def get_parser():
  646. parser = config_argparse.ArgumentParser(
  647. description="ASR Decoding",
  648. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  649. )
  650. # Note(kamo): Use '_' instead of '-' as separator.
  651. # '-' is confusing if written in yaml.
  652. parser.add_argument(
  653. "--log_level",
  654. type=lambda x: x.upper(),
  655. default="INFO",
  656. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  657. help="The verbose level of logging",
  658. )
  659. parser.add_argument("--output_dir", type=str, required=True)
  660. parser.add_argument(
  661. "--ngpu",
  662. type=int,
  663. default=0,
  664. help="The number of gpus. 0 indicates CPU mode",
  665. )
  666. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  667. parser.add_argument(
  668. "--dtype",
  669. default="float32",
  670. choices=["float16", "float32", "float64"],
  671. help="Data type",
  672. )
  673. parser.add_argument(
  674. "--num_workers",
  675. type=int,
  676. default=1,
  677. help="The number of workers used for DataLoader",
  678. )
  679. parser.add_argument(
  680. "--hotword",
  681. type=str_or_none,
  682. default=None,
  683. help="hotword file path or hotwords seperated by space"
  684. )
  685. group = parser.add_argument_group("Input data related")
  686. group.add_argument(
  687. "--data_path_and_name_and_type",
  688. type=str2triple_str,
  689. required=False,
  690. action="append",
  691. )
  692. group.add_argument("--key_file", type=str_or_none)
  693. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  694. group = parser.add_argument_group("The model configuration related")
  695. group.add_argument(
  696. "--asr_train_config",
  697. type=str,
  698. help="ASR training configuration",
  699. )
  700. group.add_argument(
  701. "--asr_model_file",
  702. type=str,
  703. help="ASR model parameter file",
  704. )
  705. group.add_argument(
  706. "--cmvn_file",
  707. type=str,
  708. help="Global cmvn file",
  709. )
  710. group.add_argument(
  711. "--lm_train_config",
  712. type=str,
  713. help="LM training configuration",
  714. )
  715. group.add_argument(
  716. "--lm_file",
  717. type=str,
  718. help="LM parameter file",
  719. )
  720. group.add_argument(
  721. "--word_lm_train_config",
  722. type=str,
  723. help="Word LM training configuration",
  724. )
  725. group.add_argument(
  726. "--word_lm_file",
  727. type=str,
  728. help="Word LM parameter file",
  729. )
  730. group.add_argument(
  731. "--ngram_file",
  732. type=str,
  733. help="N-gram parameter file",
  734. )
  735. group.add_argument(
  736. "--model_tag",
  737. type=str,
  738. help="Pretrained model tag. If specify this option, *_train_config and "
  739. "*_file will be overwritten",
  740. )
  741. group = parser.add_argument_group("Beam-search related")
  742. group.add_argument(
  743. "--batch_size",
  744. type=int,
  745. default=1,
  746. help="The batch size for inference",
  747. )
  748. group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
  749. group.add_argument("--beam_size", type=int, default=20, help="Beam size")
  750. group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
  751. group.add_argument(
  752. "--maxlenratio",
  753. type=float,
  754. default=0.0,
  755. help="Input length ratio to obtain max output length. "
  756. "If maxlenratio=0.0 (default), it uses a end-detect "
  757. "function "
  758. "to automatically find maximum hypothesis lengths."
  759. "If maxlenratio<0.0, its absolute value is interpreted"
  760. "as a constant max output length",
  761. )
  762. group.add_argument(
  763. "--minlenratio",
  764. type=float,
  765. default=0.0,
  766. help="Input length ratio to obtain min output length",
  767. )
  768. group.add_argument(
  769. "--ctc_weight",
  770. type=float,
  771. default=0.5,
  772. help="CTC weight in joint decoding",
  773. )
  774. group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
  775. group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
  776. group.add_argument("--streaming", type=str2bool, default=False)
  777. group.add_argument(
  778. "--frontend_conf",
  779. default=None,
  780. help="",
  781. )
  782. group.add_argument("--raw_inputs", type=list, default=None)
  783. # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
  784. group = parser.add_argument_group("Text converter related")
  785. group.add_argument(
  786. "--token_type",
  787. type=str_or_none,
  788. default=None,
  789. choices=["char", "bpe", None],
  790. help="The token type for ASR model. "
  791. "If not given, refers from the training args",
  792. )
  793. group.add_argument(
  794. "--bpemodel",
  795. type=str_or_none,
  796. default=None,
  797. help="The model path of sentencepiece. "
  798. "If not given, refers from the training args",
  799. )
  800. return parser
  801. def main(cmd=None):
  802. print(get_commandline_args(), file=sys.stderr)
  803. parser = get_parser()
  804. args = parser.parse_args(cmd)
  805. param_dict = {'hotword': args.hotword}
  806. kwargs = vars(args)
  807. kwargs.pop("config", None)
  808. kwargs['param_dict'] = param_dict
  809. inference(**kwargs)
  810. if __name__ == "__main__":
  811. main()
  812. # from modelscope.pipelines import pipeline
  813. # from modelscope.utils.constant import Tasks
  814. #
  815. # inference_16k_pipline = pipeline(
  816. # task=Tasks.auto_speech_recognition,
  817. # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
  818. #
  819. # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
  820. # print(rec_result)