#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import time
import torch
import logging
from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch < 1.6.0 has no autocast; provide a no-op stand-in.
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "ParaformerStreaming")
class ParaformerStreaming(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)

        self.scama_mask = None
        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder

            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
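
    # forward() below combines three training objectives: the CTC loss on the
    # encoder output, the cross-entropy loss of the non-autoregressive attention
    # decoder, and the CIF predictor's token-count (MAE) loss. With illustrative
    # weights ctc_weight=0.3 and predictor_weight=1.0 (example values, not read
    # from this file's defaults), the total works out to
    #     loss = 0.3 * loss_ctc + 0.7 * loss_att + 1.0 * loss_pre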

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        decoding_ind = kwargs.get("decoding_ind")
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # 1. Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # 2. Decoder: CTC branch
        if self.ctc_weight > 0.0:
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                    encoder_out, encoder_out_lens, chunk_outs=None
                )
            else:
                encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # 3. Decoder: attention branch with CIF predictor
        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 4. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder for one streaming chunk. Note that this method is used by asr_inference.py.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            cache: streaming state dict created by init_cache()
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

            # Forward encoder on the current chunk, reusing cached encoder state
            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
                speech, speech_lengths, cache=cache["encoder"]
            )
            if isinstance(encoder_out, tuple):
                encoder_out = encoder_out[0]

        return encoder_out, torch.tensor([encoder_out.size(1)])

    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias

        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk

        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_pad_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens
        )

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == "chunk":
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_pad_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        # 0. sampler
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, scama_mask
                )
            else:
                sematic_embeds, decoder_out_1st = self.sampler(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, scama_mask
                )
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
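
    # The sampler below implements Paraformer's glancing-style sampling: a first
    # decoder pass (run without gradients) predicts tokens from the CIF acoustic
    # embeddings, and a fraction of the mispredicted positions is then replaced by
    # ground-truth character embeddings before the second, gradient-carrying pass.
    # Worked example with an assumed sampling_ratio of 0.75 (this class defaults to
    # 0.2): for a sentence with 10 non-padding tokens of which the first pass gets
    # 6 right, target_num = int((10 - 6) * 0.75) = 3 random positions are switched
    # to their ground-truth embeddings, and the rest keep the acoustic embeddings.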

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)

        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0
        )
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            encoder_out = encoder_out * mask_shfit_chunk

        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
            encoder_out,
            None,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=None,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens,
        )

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == "chunk":
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=None,
                is_training=self.training,
            )
        self.scama_mask = scama_mask

        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index

    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
        is_final = kwargs.get("is_final", False)
        return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
        decoder_outs = self.decoder.forward_chunk(
            encoder_out, sematic_embeds, cache["decoder"]
        )
        decoder_out = decoder_outs
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
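
    # init_cache() below builds the per-stream state that is threaded through
    # encode_chunk(), calc_predictor_chunk() and cal_decoder_with_predictor_chunk():
    # cache["encoder"] holds the chunked-encoder state plus the CIF carry-over
    # ("cif_hidden", "cif_alphas"), cache["decoder"] holds the FSMN decoder state,
    # cache["frontend"] the fbank/LFR state, and cache["prev_samples"] carries raw
    # audio samples between successive inference() calls. As an illustration, with
    # chunk_size=[0, 10, 5] and an assumed 80-dim fbank frontend with lfr_m=7 and a
    # 512-dim encoder, cache["encoder"]["feats"] starts as zeros of shape
    # (1, 5, 560) and cache["encoder"]["cif_hidden"] as zeros of shape (1, 1, 512).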

    def init_cache(self, cache: dict = {}, **kwargs):
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1

        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]

        cache_encoder = {
            "start_idx": 0,
            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
            "cif_alphas": torch.zeros((batch_size, 1)),
            "chunk_size": chunk_size,
            "encoder_chunk_look_back": encoder_chunk_look_back,
            "last_chunk": False,
            "opt": None,
            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
            "tail_chunk": False,
        }
        cache["encoder"] = cache_encoder

        cache_decoder = {
            "decode_fsmn": None,
            "decoder_chunk_look_back": decoder_chunk_look_back,
            "opt": None,
            "chunk_size": chunk_size,
        }
        cache["decoder"] = cache_decoder
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)

        return cache

    def generate_chunk(
        self,
        speech,
        speech_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        cache = kwargs.get("cache", {})
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(
            speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False)
        )
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor_chunk(
            encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False)
        )
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []

        decoder_outs = self.cal_decoder_with_predictor_chunk(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length, cache=cache
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        for i in range(b):
            x = encoder_out[i, : encoder_out_lens[i], :]
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # wrap with sos/eos so the hypothesis matches the beam-search output format
                yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                # text = tokenizer.tokens2text(token)

                result_i = token
                results.extend(result_i)

        return results
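
    # inference() below drives the streaming loop for one incoming piece of audio:
    # new samples are concatenated with cache["prev_samples"], split into chunks of
    # chunk_size[1] * 960 samples (9600 samples = 600 ms at 16 kHz for the default
    # chunk_size of [0, 10, 5]), run through fbank extraction and generate_chunk(),
    # and the accumulated tokens are post-processed into text. When is_final is
    # set, the cache is re-initialized so the next call starts a fresh stream.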

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        cache: dict = {},
        **kwargs,
    ):
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

        meta_data = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        chunk_stride_samples = int(chunk_size[1] * 960)  # 9600 samples = 600 ms at 16 kHz when chunk_size[1] == 10

        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
            cache=cfg,
        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, is_final is set to True
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set 1"

        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))

        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
        tokens = []
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples : (i + 1) * chunk_stride_samples]

            # extract fbank feats
            speech, speech_lengths = extract_fbank(
                [audio_sample_i],
                data_type=kwargs.get("data_type", "sound"),
                frontend=frontend,
                cache=cache["frontend"],
                is_final=kwargs["is_final"],
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

            tokens_i = self.generate_chunk(
                speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs
            )
            tokens.extend(tokens_i)

        text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
        result_i = {"key": key[0], "text": text_postprocessed}
        result = [result_i]

        cache["prev_samples"] = audio_sample[:-m]
        if _is_final:
            self.init_cache(cache, **kwargs)

        if kwargs.get("output_dir"):
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]
            ibest_writer["token"][key[0]] = " ".join(tokens)
            ibest_writer["text"][key[0]] = text_postprocessed

        return result, meta_data
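

# A minimal usage sketch for streaming recognition with this model, assuming the
# FunASR AutoModel wrapper and the "paraformer-zh-streaming" model id; the WAV path
# below is a placeholder and must point to 16 kHz mono audio. AutoModel.generate()
# forwards cache / is_final / chunk_size and the look-back settings down to
# ParaformerStreaming.inference() above.
if __name__ == "__main__":
    import soundfile
    from funasr import AutoModel

    chunk_size = [0, 10, 5]  # commonly documented setting: 600 ms chunks with lookahead
    encoder_chunk_look_back = 4  # number of past chunks visible to encoder self-attention
    decoder_chunk_look_back = 1  # number of past encoder chunks visible to decoder cross-attention

    model = AutoModel(model="paraformer-zh-streaming")

    speech, sample_rate = soundfile.read("example_16k.wav")  # placeholder path
    chunk_stride = chunk_size[1] * 960  # samples per 600 ms chunk at 16 kHz

    cache = {}
    total_chunk_num = int(len(speech) - 1) // chunk_stride + 1
    for i in range(total_chunk_num):
        speech_chunk = speech[i * chunk_stride : (i + 1) * chunk_stride]
        is_final = i == total_chunk_num - 1
        res = model.generate(
            input=speech_chunk,
            cache=cache,
            is_final=is_final,
            chunk_size=chunk_size,
            encoder_chunk_look_back=encoder_chunk_look_back,
            decoder_chunk_look_back=decoder_chunk_look_back,
        )
        print(res)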