# model.py
  1. import os
  2. import logging
  3. from contextlib import contextmanager
  4. from distutils.version import LooseVersion
  5. from typing import Dict
  6. from typing import List
  7. from typing import Optional
  8. from typing import Tuple
  9. from typing import Union
  10. import tempfile
  11. import codecs
  12. import requests
  13. import re
  14. import copy
  15. import torch
  16. import torch.nn as nn
  17. import random
  18. import numpy as np
  19. import time
  20. # from funasr.layers.abs_normalize import AbsNormalize
  21. from funasr.losses.label_smoothing_loss import (
  22. LabelSmoothingLoss, # noqa: H301
  23. )
  24. from funasr.models.paraformer.cif_predictor import mae_loss
  25. from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
  26. from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
  27. from funasr.metrics.compute_acc import th_accuracy
  28. from funasr.train_utils.device_funcs import force_gatherable
  29. from funasr.models.paraformer.search import Hypothesis
  30. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  31. from torch.cuda.amp import autocast
  32. else:
  33. # Nothing to do if torch<1.6.0
  34. @contextmanager
  35. def autocast(enabled=True):
  36. yield
  37. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
  38. from funasr.utils import postprocess_utils
  39. from funasr.utils.datadir_writer import DatadirWriter
  40. from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
  41. from funasr.models.ctc.ctc import CTC
  42. from funasr.models.paraformer.model import Paraformer
  43. from funasr.register import tables
@tables.register("model_classes", "ParaformerStreaming")
class ParaformerStreaming(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        """Streaming variant of Paraformer.

        Keyword Args (read from ``kwargs``):
            sampling_ratio: fraction of wrong 1st-pass tokens to replace with
                ground-truth embeddings in the sampler (default 0.2).
            decoder_attention_chunk_type: cross-attention chunking mode for the
                SCAMA decoder mask (default "chunk"); only used when the encoder
                carries an ``overlap_chunk_cls`` helper.
        """
        super().__init__(*args, **kwargs)
        self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
        # SCAMA cross-attention mask, built lazily by calc_predictor().
        self.scama_mask = None
        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
            # Imported lazily: only needed for chunk-overlapped (streaming) encoders.
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            (loss, stats, weight) made gatherable for DataParallel.
        """
        decoding_ind = kwargs.get("decoding_ind")
        # Lengths may arrive as (Batch, 1); squeeze them to (Batch,).
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # Encoder (with a randomly chosen chunk configuration when the encoder
        # carries an overlap-chunk helper — streaming-style training).
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()
        # decoder: CTC branch
        if self.ctc_weight > 0.0:
            # CTC is computed on the de-chunked (original timeline) encoder output.
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                    encoder_out,
                    encoder_out_lens,
                    chunk_outs=None)
            else:
                encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # decoder: Attention decoder branch (also yields the CIF predictor loss)
        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            # Normalize by total token count (incl. sos/eos bias) instead of batch size.
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  134. def encode_chunk(
  135. self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
  136. ) -> Tuple[torch.Tensor, torch.Tensor]:
  137. """Frontend + Encoder. Note that this method is used by asr_inference.py
  138. Args:
  139. speech: (Batch, Length, ...)
  140. speech_lengths: (Batch, )
  141. ind: int
  142. """
  143. with autocast(False):
  144. # Data augmentation
  145. if self.specaug is not None and self.training:
  146. speech, speech_lengths = self.specaug(speech, speech_lengths)
  147. # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  148. if self.normalize is not None:
  149. speech, speech_lengths = self.normalize(speech, speech_lengths)
  150. # Forward encoder
  151. encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
  152. if isinstance(encoder_out, tuple):
  153. encoder_out = encoder_out[0]
  154. return encoder_out, torch.tensor([encoder_out.size(1)])
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute attention-decoder loss and CIF predictor loss for one batch.

        Runs the CIF predictor on (possibly chunk-overlapped) encoder output,
        builds the SCAMA cross-attention mask when the encoder is chunked,
        optionally mixes ground-truth embeddings via the sampler, then decodes
        and computes CE / predictor-length losses.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att)
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        # Append eos so the predictor target length includes it.
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            # Zero out the shifted-chunk (look-ahead) frames before the predictor.
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_pad_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Build the streaming (SCAMA) cross-attention mask from the CIF alignments.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_pad_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            # Non-chunk attention: decode against the de-chunked encoder output.
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out,
                encoder_out_lens,
                chunk_outs=None)
        # 0. sampler: replace a fraction of wrong 1st-pass tokens with GT embeddings.
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            # Log the sampler setting only for the first couple of steps.
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = \
                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
            else:
                sematic_embeds, decoder_out_1st = \
                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss (accuracy uses the 1st-pass output).
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        # Predictor (CIF) length loss against the true target lengths.
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
  255. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
  256. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  257. ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
  258. if self.share_embedding:
  259. ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
  260. else:
  261. ys_pad_embed = self.decoder.embed(ys_pad_masked)
  262. with torch.no_grad():
  263. decoder_outs = self.decoder(
  264. encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
  265. )
  266. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  267. pred_tokens = decoder_out.argmax(-1)
  268. nonpad_positions = ys_pad.ne(self.ignore_id)
  269. seq_lens = (nonpad_positions).sum(1)
  270. same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
  271. input_mask = torch.ones_like(nonpad_positions)
  272. bsz, seq_len = ys_pad.size()
  273. for li in range(bsz):
  274. target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
  275. if target_num > 0:
  276. input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
  277. input_mask = input_mask.eq(1)
  278. input_mask = input_mask.masked_fill(~nonpad_positions, False)
  279. input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
  280. sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
  281. input_mask_expand_dim, 0)
  282. return sematic_embeds * tgt_mask, decoder_out * tgt_mask
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the CIF predictor over the full encoder output (offline pass).

        Side effect: builds and stores the SCAMA cross-attention mask in
        ``self.scama_mask`` for the subsequent ``cal_decoder_with_predictor`` call.

        Returns:
            (pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index)
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            # Zero out the shifted-chunk (look-ahead) frames before the predictor.
            encoder_out = encoder_out * mask_shfit_chunk
        # Inference: no target labels, so lengths are predicted from alphas alone.
        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
            encoder_out,
            None,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=None,
        )
        # One extra frame when the predictor appends a tail frame (tail_threshold > 0).
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Build the streaming (SCAMA) cross-attention mask from CIF alignments.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=None,
                is_training=self.training,
            )
        self.scama_mask = scama_mask
        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
  332. def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
  333. is_final = kwargs.get("is_final", False)
  334. return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
  335. def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
  336. decoder_outs = self.decoder(
  337. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
  338. )
  339. decoder_out = decoder_outs[0]
  340. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  341. return decoder_out, ys_pad_lens
  342. def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
  343. decoder_outs = self.decoder.forward_chunk(
  344. encoder_out, sematic_embeds, cache["decoder"]
  345. )
  346. decoder_out = decoder_outs
  347. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  348. return decoder_out, ys_pad_lens
  349. def init_cache(self, cache: dict = {}, **kwargs):
  350. chunk_size = kwargs.get("chunk_size", [0, 10, 5])
  351. encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
  352. decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
  353. batch_size = 1
  354. enc_output_size = kwargs["encoder_conf"]["output_size"]
  355. feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
  356. cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
  357. "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
  358. "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
  359. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
  360. "tail_chunk": False}
  361. cache["encoder"] = cache_encoder
  362. cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
  363. "chunk_size": chunk_size}
  364. cache["decoder"] = cache_decoder
  365. cache["frontend"] = {}
  366. cache["prev_samples"] = torch.empty(0)
  367. return cache
  368. def generate_chunk(self,
  369. speech,
  370. speech_lengths=None,
  371. key: list = None,
  372. tokenizer=None,
  373. frontend=None,
  374. **kwargs,
  375. ):
  376. cache = kwargs.get("cache", {})
  377. speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
  378. # Encoder
  379. encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
  380. if isinstance(encoder_out, tuple):
  381. encoder_out = encoder_out[0]
  382. # predictor
  383. predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
  384. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
  385. predictor_outs[2], predictor_outs[3]
  386. pre_token_length = pre_token_length.round().long()
  387. if torch.max(pre_token_length) < 1:
  388. return []
  389. decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
  390. encoder_out_lens,
  391. pre_acoustic_embeds,
  392. pre_token_length,
  393. cache=cache
  394. )
  395. decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
  396. results = []
  397. b, n, d = decoder_out.size()
  398. if isinstance(key[0], (list, tuple)):
  399. key = key[0]
  400. for i in range(b):
  401. x = encoder_out[i, :encoder_out_lens[i], :]
  402. am_scores = decoder_out[i, :pre_token_length[i], :]
  403. if self.beam_search is not None:
  404. nbest_hyps = self.beam_search(
  405. x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
  406. minlenratio=kwargs.get("minlenratio", 0.0)
  407. )
  408. nbest_hyps = nbest_hyps[: self.nbest]
  409. else:
  410. yseq = am_scores.argmax(dim=-1)
  411. score = am_scores.max(dim=-1)[0]
  412. score = torch.sum(score, dim=-1)
  413. # pad with mask tokens to ensure compatibility with sos/eos tokens
  414. yseq = torch.tensor(
  415. [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
  416. )
  417. nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
  418. for nbest_idx, hyp in enumerate(nbest_hyps):
  419. # remove sos/eos and get results
  420. last_pos = -1
  421. if isinstance(hyp.yseq, list):
  422. token_int = hyp.yseq[1:last_pos]
  423. else:
  424. token_int = hyp.yseq[1:last_pos].tolist()
  425. # remove blank symbol id, which is assumed to be 0
  426. token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
  427. # Change integer-ids to tokens
  428. token = tokenizer.ids2tokens(token_int)
  429. # text = tokenizer.tokens2text(token)
  430. result_i = token
  431. results.extend(result_i)
  432. return results
  433. def generate(self,
  434. data_in,
  435. data_lengths=None,
  436. key: list = None,
  437. tokenizer=None,
  438. frontend=None,
  439. cache: dict={},
  440. **kwargs,
  441. ):
  442. # init beamsearch
  443. is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
  444. is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
  445. if self.beam_search is None and (is_use_lm or is_use_ctc):
  446. logging.info("enable beam_search")
  447. self.init_beam_search(**kwargs)
  448. self.nbest = kwargs.get("nbest", 1)
  449. if len(cache) == 0:
  450. self.init_cache(cache, **kwargs)
  451. _is_final = kwargs.get("is_final", False)
  452. meta_data = {}
  453. chunk_size = kwargs.get("chunk_size", [0, 10, 5])
  454. chunk_stride_samples = chunk_size[1] * 960 # 600ms
  455. time1 = time.perf_counter()
  456. audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
  457. data_type=kwargs.get("data_type", "sound"),
  458. tokenizer=tokenizer)
  459. time2 = time.perf_counter()
  460. meta_data["load_data"] = f"{time2 - time1:0.3f}"
  461. assert len(audio_sample_list) == 1, "batch_size must be set 1"
  462. audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
  463. n = len(audio_sample) // chunk_stride_samples + int(_is_final)
  464. m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
  465. tokens = []
  466. for i in range(n):
  467. kwargs["is_final"] = _is_final and i == n -1
  468. audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
  469. # extract fbank feats
  470. speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
  471. frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
  472. time3 = time.perf_counter()
  473. meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
  474. meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
  475. tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
  476. tokens.extend(tokens_i)
  477. text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
  478. result_i = {"key": key[0], "text": text_postprocessed}
  479. result = [result_i]
  480. cache["prev_samples"] = audio_sample[:-m]
  481. if _is_final:
  482. self.init_cache(cache, **kwargs)
  483. if kwargs.get("output_dir"):
  484. writer = DatadirWriter(kwargs.get("output_dir"))
  485. ibest_writer = writer[f"{1}best_recog"]
  486. ibest_writer["token"][key[0]] = " ".join(tokens)
  487. ibest_writer["text"][key[0]] = text_postprocessed
  488. return result, meta_data