# model.py
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import time
  6. import torch
  7. import logging
  8. from typing import Dict, Tuple
  9. from contextlib import contextmanager
  10. from distutils.version import LooseVersion
  11. from funasr.register import tables
  12. from funasr.models.ctc.ctc import CTC
  13. from funasr.utils import postprocess_utils
  14. from funasr.metrics.compute_acc import th_accuracy
  15. from funasr.utils.datadir_writer import DatadirWriter
  16. from funasr.models.paraformer.model import Paraformer
  17. from funasr.models.paraformer.search import Hypothesis
  18. from funasr.models.paraformer.cif_predictor import mae_loss
  19. from funasr.train_utils.device_funcs import force_gatherable
  20. from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
  21. from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
  22. from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
  23. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
# torch >= 1.6 ships native AMP autocast; on older versions fall back to a
# no-op context manager so call sites can write `with autocast(False):`
# unconditionally.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
@tables.register("model_classes", "ParaformerStreaming")
class ParaformerStreaming(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Streaming Paraformer; all arguments are forwarded to Paraformer.

        Extra kwargs consumed here:
            sampling_ratio (float): fraction of ground-truth token embeddings
                mixed into the predictor embeddings by the sampler (default 0.2).
            decoder_attention_chunk_type (str): SCAMA cross-attention masking
                mode; only read when the encoder is chunk-based (default "chunk").
        """
        super().__init__(*args, **kwargs)

        self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
        # SCAMA cross-attention mask; written by calc_predictor() and consumed
        # later by cal_decoder_with_predictor().
        self.scama_mask = None
        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
            # Imported lazily: only chunk-based encoders need the SCAMA mask builder.
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder

            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            **kwargs: may carry "decoding_ind" to pin the chunk configuration.

        Returns:
            (loss, stats dict, weight) made gather-able for DataParallel.
        """
        decoding_ind = kwargs.get("decoding_ind")
        # Collapse possible (Batch, 1) length tensors to (Batch,).
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder: chunk-based encoders pick a (possibly random) chunk setup.
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight > 0.0:
            # CTC is computed on the un-chunked encoder output.
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                    encoder_out,
                    encoder_out_lens,
                    chunk_outs=None)
            else:
                encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition (predictor loss is always added, weighted).
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            # Normalize by token count (+bias for the appended eos) instead of batch size.
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  121. def encode_chunk(
  122. self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
  123. ) -> Tuple[torch.Tensor, torch.Tensor]:
  124. """Frontend + Encoder. Note that this method is used by asr_inference.py
  125. Args:
  126. speech: (Batch, Length, ...)
  127. speech_lengths: (Batch, )
  128. ind: int
  129. """
  130. with autocast(False):
  131. # Data augmentation
  132. if self.specaug is not None and self.training:
  133. speech, speech_lengths = self.specaug(speech, speech_lengths)
  134. # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  135. if self.normalize is not None:
  136. speech, speech_lengths = self.normalize(speech, speech_lengths)
  137. # Forward encoder
  138. encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
  139. if isinstance(encoder_out, tuple):
  140. encoder_out = encoder_out[0]
  141. return encoder_out, torch.tensor([encoder_out.size(1)])
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder + CIF-predictor losses for one batch.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att)
        """
        # (B, 1, T) boolean mask of valid encoder frames.
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            # Append <eos> to the targets; the predictor must count it too.
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            # Zero out encoder frames belonging to the shifted/overlap part of each chunk.
            encoder_out = encoder_out * mask_shfit_chunk
        # CIF predictor: one acoustic embedding is fired per target token.
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_pad_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens)

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Build the SCAMA cross-attention mask from the frame alignments.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_pad_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            # Non-chunk attention type: decode on the un-chunked encoder output.
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out,
                encoder_out_lens,
                chunk_outs=None)

        # 0. sampler: optionally mix ground-truth embeddings into the
        # predictor embeddings (glancing-style training).
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = \
                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
            else:
                sematic_embeds, decoder_out_1st = \
                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
        else:
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss (accuracy uses the first-pass output).
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        # Predictor (token-count) loss.
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder (evaluation only).
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
  238. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
  239. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  240. ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
  241. if self.share_embedding:
  242. ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
  243. else:
  244. ys_pad_embed = self.decoder.embed(ys_pad_masked)
  245. with torch.no_grad():
  246. decoder_outs = self.decoder(
  247. encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
  248. )
  249. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  250. pred_tokens = decoder_out.argmax(-1)
  251. nonpad_positions = ys_pad.ne(self.ignore_id)
  252. seq_lens = (nonpad_positions).sum(1)
  253. same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
  254. input_mask = torch.ones_like(nonpad_positions)
  255. bsz, seq_len = ys_pad.size()
  256. for li in range(bsz):
  257. target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
  258. if target_num > 0:
  259. input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
  260. input_mask = input_mask.eq(1)
  261. input_mask = input_mask.masked_fill(~nonpad_positions, False)
  262. input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
  263. sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
  264. input_mask_expand_dim, 0)
  265. return sematic_embeds * tgt_mask, decoder_out * tgt_mask
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the CIF predictor offline and prepare the SCAMA decoder mask.

        Side effect: stores the SCAMA cross-attention mask in `self.scama_mask`
        for the subsequent cal_decoder_with_predictor() call.

        Args:
            encoder_out: encoder output (B, T, D).
            encoder_out_lens: valid encoder lengths (B,).

        Returns:
            (pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index)
            as produced by the predictor.
        """
        # (B, 1, T) boolean mask of valid encoder frames.
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            # Zero out encoder frames belonging to the shifted/overlap chunk parts.
            encoder_out = encoder_out * mask_shfit_chunk
        # Inference mode: no targets, the predictor estimates the token count itself.
        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
            encoder_out,
            None,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=None,
        )
        # With a tail threshold the predictor may fire one extra (tail) token,
        # hence the +1 on the alignment length.
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Build the SCAMA cross-attention mask from the frame alignments.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=None,
                is_training=self.training,
            )
        self.scama_mask = scama_mask

        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
  315. def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
  316. is_final = kwargs.get("is_final", False)
  317. return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
  318. def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
  319. decoder_outs = self.decoder(
  320. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
  321. )
  322. decoder_out = decoder_outs[0]
  323. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  324. return decoder_out, ys_pad_lens
  325. def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
  326. decoder_outs = self.decoder.forward_chunk(
  327. encoder_out, sematic_embeds, cache["decoder"]
  328. )
  329. decoder_out = decoder_outs
  330. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  331. return decoder_out, ys_pad_lens
  332. def init_cache(self, cache: dict = {}, **kwargs):
  333. chunk_size = kwargs.get("chunk_size", [0, 10, 5])
  334. encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
  335. decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
  336. batch_size = 1
  337. enc_output_size = kwargs["encoder_conf"]["output_size"]
  338. feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
  339. cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  340. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
  341. "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
  342. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
  343. "tail_chunk": False}
  344. cache["encoder"] = cache_encoder
  345. cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
  346. "chunk_size": chunk_size}
  347. cache["decoder"] = cache_decoder
  348. cache["frontend"] = {}
  349. cache["prev_samples"] = torch.empty(0)
  350. return cache
    def generate_chunk(self,
                       speech,
                       speech_lengths=None,
                       key: list = None,
                       tokenizer=None,
                       frontend=None,
                       **kwargs,
                       ):
        """Decode one chunk of fbank features into a flat list of token strings.

        Args:
            speech: fbank features for this chunk (Batch, Frames, Dim).
            speech_lengths: valid frame counts (Batch,).
            key: utterance id(s).
            tokenizer: converts token ids to token strings.
            frontend: unused here (features are already extracted).
            **kwargs: must carry "device" and the streaming "cache";
                "is_final" marks the last chunk of the stream.

        Returns:
            List of token strings; empty if the predictor fired no token.
        """
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor: CIF fires one acoustic embedding per predicted token.
        predictor_outs = self.calc_predictor_chunk(
            encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
            predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        # No token fired in this chunk: nothing to decode.
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
                                                             encoder_out_lens,
                                                             pre_acoustic_embeds,
                                                             pre_token_length,
                                                             cache=cache
                                                             )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0)
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                # Greedy path: arg-max token per fired position.
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                # remove sos/eos and get results
                # NOTE(review): slicing [1:-1] assumes yseq both starts with sos
                # and ends with eos; confirm beam-search hypotheses follow the
                # same convention as the greedy branch above.
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                # text = tokenizer.tokens2text(token)
                result_i = token
                results.extend(result_i)
        return results
  417. def inference(self,
  418. data_in,
  419. data_lengths=None,
  420. key: list = None,
  421. tokenizer=None,
  422. frontend=None,
  423. cache: dict={},
  424. **kwargs,
  425. ):
  426. # init beamsearch
  427. is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
  428. is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
  429. if self.beam_search is None and (is_use_lm or is_use_ctc):
  430. logging.info("enable beam_search")
  431. self.init_beam_search(**kwargs)
  432. self.nbest = kwargs.get("nbest", 1)
  433. if len(cache) == 0:
  434. self.init_cache(cache, **kwargs)
  435. meta_data = {}
  436. chunk_size = kwargs.get("chunk_size", [0, 10, 5])
  437. chunk_stride_samples = int(chunk_size[1] * 960) # 600ms
  438. time1 = time.perf_counter()
  439. cfg = {"is_final": kwargs.get("is_final", False)}
  440. audio_sample_list = load_audio_text_image_video(data_in,
  441. fs=frontend.fs,
  442. audio_fs=kwargs.get("fs", 16000),
  443. data_type=kwargs.get("data_type", "sound"),
  444. tokenizer=tokenizer,
  445. cache=cfg,
  446. )
  447. _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
  448. time2 = time.perf_counter()
  449. meta_data["load_data"] = f"{time2 - time1:0.3f}"
  450. assert len(audio_sample_list) == 1, "batch_size must be set 1"
  451. audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
  452. n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
  453. m = int(len(audio_sample) % chunk_stride_samples * (1-int(_is_final)))
  454. tokens = []
  455. for i in range(n):
  456. kwargs["is_final"] = _is_final and i == n -1
  457. audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
  458. # extract fbank feats
  459. speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
  460. frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
  461. time3 = time.perf_counter()
  462. meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
  463. meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
  464. tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
  465. tokens.extend(tokens_i)
  466. text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
  467. result_i = {"key": key[0], "text": text_postprocessed}
  468. result = [result_i]
  469. cache["prev_samples"] = audio_sample[:-m]
  470. if _is_final:
  471. self.init_cache(cache, **kwargs)
  472. if kwargs.get("output_dir"):
  473. if not hasattr(self, "writer"):
  474. self.writer = DatadirWriter(kwargs.get("output_dir"))
  475. ibest_writer = self.writer[f"{1}best_recog"]
  476. ibest_writer["token"][key[0]] = " ".join(tokens)
  477. ibest_writer["text"][key[0]] = text_postprocessed
  478. return result, meta_data