  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import time
  6. import torch
  7. import torch.nn as nn
  8. import torch.functional as F
  9. import logging
  10. from typing import Dict, Tuple
  11. from contextlib import contextmanager
  12. from distutils.version import LooseVersion
  13. from funasr.register import tables
  14. from funasr.models.ctc.ctc import CTC
  15. from funasr.utils import postprocess_utils
  16. from funasr.metrics.compute_acc import th_accuracy
  17. from funasr.utils.datadir_writer import DatadirWriter
  18. from funasr.models.paraformer.model import Paraformer
  19. from funasr.models.paraformer.search import Hypothesis
  20. from funasr.models.paraformer.cif_predictor import mae_loss
  21. from funasr.train_utils.device_funcs import force_gatherable
  22. from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
  23. from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
  24. from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
  25. from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
  26. from funasr.models.scama.utils import sequence_mask
# AMP compatibility shim: torch.cuda.amp.autocast only exists from torch 1.6.0
# onwards.  On older torch, fall back to a no-op context manager with the same
# call signature so `with autocast(False):` works everywhere.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
@tables.register("model_classes", "SCAMA")
class SCAMA(nn.Module):
    """Streaming chunk-aware multihead attention (SCAMA) ASR model.

    Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01712
    """

    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        predictor: str = None,
        predictor_conf: dict = None,
        predictor_bias: int = 0,
        predictor_weight: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        share_embedding: bool = False,
        **kwargs,
    ):
        """Build the model from registry names plus per-module config dicts.

        Args:
            specaug / normalize / encoder / decoder / predictor: registry *names*;
                each is resolved via ``tables.*_classes.get`` and instantiated
                with the matching ``*_conf`` dict.
            ctc_weight: interpolation weight for the CTC loss.  0.0 disables the
                CTC branch entirely; 1.0 disables the attention decoder.
            predictor_weight: weight of the predictor (token-count MAE) loss.
            input_size: feature dimension fed to the encoder (default 80 fbank).
            sos / eos: start/end-of-sequence ids; if None, vocab_size - 1 is used.
            kwargs: may carry "decoder_attention_chunk_type" (default "chunk").
        """
        super().__init__()

        # Resolve and instantiate optional augmentation / normalization modules.
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )
        # CTC module is only built when it contributes to the loss.
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )
        predictor_class = tables.predictor_classes.get(predictor)
        predictor = predictor_class(**predictor_conf)

        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder

        # Pure-CTC configuration needs no attention decoder.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        # MAE loss between predicted and reference token counts.
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)

        self.share_embedding = share_embedding
        if self.share_embedding:
            # NOTE(review): embedding is dropped here; presumably the decoder
            # falls back to a shared output embedding — confirm in decoder impl.
            self.decoder.embed = None

        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
        self.error_calculator = None

        # Chunk-streaming support: the SCAMA cross-attention mask builder is only
        # needed when the encoder is configured with overlapping chunks.
        if self.encoder.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            (loss, stats, weight) prepared by ``force_gatherable`` for
            DataParallel-style gathering.
        """
        decoding_ind = kwargs.get("decoding_ind")
        # Lengths may arrive as (Batch, 1); squeeze them down to (Batch,).
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder — chunk layout index is sampled during training, fixed when decoding.
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch (computed on the de-chunked encoder output)
        if self.ctc_weight > 0.0:
            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out,
                encoder_out_lens,
                chunk_outs=None,
            )
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # decoder: Attention decoder branch
        # NOTE(review): this branch runs even when ctc_weight == 1.0, although
        # self.decoder is None in that case — confirm pure-CTC training is unused.
        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            # Weight by token count (incl. predictor bias) instead of batch size.
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  191. def encode(
  192. self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
  193. ) -> Tuple[torch.Tensor, torch.Tensor]:
  194. """Encoder. Note that this method is used by asr_inference.py
  195. Args:
  196. speech: (Batch, Length, ...)
  197. speech_lengths: (Batch, )
  198. ind: int
  199. """
  200. with autocast(False):
  201. # Data augmentation
  202. if self.specaug is not None and self.training:
  203. speech, speech_lengths = self.specaug(speech, speech_lengths)
  204. # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  205. if self.normalize is not None:
  206. speech, speech_lengths = self.normalize(speech, speech_lengths)
  207. # Forward encoder
  208. encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
  209. if isinstance(encoder_out, tuple):
  210. encoder_out = encoder_out[0]
  211. return encoder_out, encoder_out_lens
  212. def encode_chunk(
  213. self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
  214. ) -> Tuple[torch.Tensor, torch.Tensor]:
  215. """Frontend + Encoder. Note that this method is used by asr_inference.py
  216. Args:
  217. speech: (Batch, Length, ...)
  218. speech_lengths: (Batch, )
  219. ind: int
  220. """
  221. with autocast(False):
  222. # Data augmentation
  223. if self.specaug is not None and self.training:
  224. speech, speech_lengths = self.specaug(speech, speech_lengths)
  225. # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  226. if self.normalize is not None:
  227. speech, speech_lengths = self.normalize(speech, speech_lengths)
  228. # Forward encoder
  229. encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
  230. if isinstance(encoder_out, tuple):
  231. encoder_out = encoder_out[0]
  232. return encoder_out, torch.tensor([encoder_out.size(1)])
  233. def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
  234. is_final = kwargs.get("is_final", False)
  235. return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder loss + predictor (token count) loss for one batch.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1  # +1 for the prepended sos

        # (B, 1, T) mask over valid encoder frames.
        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
                                         device=encoder_out.device)[:, None, :]
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device,
                batch_size=encoder_out.size(0))
            # Zero out the frames masked by the chunk-shift pattern.
            encoder_out = encoder_out * mask_shfit_chunk

        # Predictor: per-token acoustic embeddings + predicted token count + alphas.
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens)

        # Build the SCAMA cross-attention mask from the predicted alignments.
        encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
        attention_chunk_center_bias = 0
        attention_chunk_size = encoder_chunk_size
        decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
        mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
            None,
            device=encoder_out.device,
            batch_size=encoder_out.size(0))
        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
            predictor_alignments=predictor_alignments,
            encoder_sequence_length=encoder_out_lens,
            chunk_size=1,
            encoder_chunk_size=encoder_chunk_size,
            attention_chunk_center_bias=attention_chunk_center_bias,
            attention_chunk_size=attention_chunk_size,
            attention_chunk_type=self.decoder_attention_chunk_type,
            step=None,
            predictor_mask_chunk_hopping=mask_chunk_predictor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
            mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
            target_length=ys_in_lens,
            is_training=self.training,
        )

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # predictor loss: MAE between target and predicted token counts.
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre
    def calc_predictor_mask(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor = None,
        ys_pad_lens: torch.Tensor = None,
    ):
        """Predictor pass without targets: acoustic embeds + SCAMA attention mask.

        Unlike _calc_att_predictor_loss, no reference transcript is available
        here, so the predictor runs with ys_out_pad / ys_in_lens set to None.

        Returns:
            (pre_acoustic_embeds, pre_token_length, predictor_alignments,
             predictor_alignments_len, scama_mask)
        """
        ys_out_pad, ys_in_lens = None, None

        # (B, 1, T) mask over valid encoder frames.
        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
                                         device=encoder_out.device)[:, None, :]
        mask_chunk_predictor = None
        mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
            None,
            device=encoder_out.device,
            batch_size=encoder_out.size(0))
        mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
            None, device=encoder_out.device,
            batch_size=encoder_out.size(0))
        # Zero out the frames masked by the chunk-shift pattern.
        encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens)

        # Build the SCAMA cross-attention mask from the predicted alignments.
        encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
        attention_chunk_center_bias = 0
        attention_chunk_size = encoder_chunk_size
        decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
        mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
            None,
            device=encoder_out.device,
            batch_size=encoder_out.size(0))
        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
            predictor_alignments=predictor_alignments,
            encoder_sequence_length=encoder_out_lens,
            chunk_size=1,
            encoder_chunk_size=encoder_chunk_size,
            attention_chunk_center_bias=attention_chunk_center_bias,
            attention_chunk_size=attention_chunk_size,
            attention_chunk_type=self.decoder_attention_chunk_type,
            step=None,
            predictor_mask_chunk_hopping=mask_chunk_predictor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
            mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
            target_length=ys_in_lens,
            is_training=self.training,
        )
        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
  367. def init_beam_search(self,
  368. **kwargs,
  369. ):
  370. from funasr.models.scama.beam_search import BeamSearchScamaStreaming
  371. from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
  372. from funasr.models.transformer.scorers.length_bonus import LengthBonus
  373. # 1. Build ASR model
  374. scorers = {}
  375. if self.ctc != None:
  376. ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
  377. scorers.update(
  378. ctc=ctc
  379. )
  380. token_list = kwargs.get("token_list")
  381. scorers.update(
  382. decoder=self.decoder,
  383. length_bonus=LengthBonus(len(token_list)),
  384. )
  385. # 3. Build ngram model
  386. # ngram is not supported now
  387. ngram = None
  388. scorers["ngram"] = ngram
  389. weights = dict(
  390. decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
  391. ctc=kwargs.get("decoding_ctc_weight", 0.0),
  392. lm=kwargs.get("lm_weight", 0.0),
  393. ngram=kwargs.get("ngram_weight", 0.0),
  394. length_bonus=kwargs.get("penalty", 0.0),
  395. )
  396. beam_search = BeamSearchScamaStreaming(
  397. beam_size=kwargs.get("beam_size", 2),
  398. weights=weights,
  399. scorers=scorers,
  400. sos=self.sos,
  401. eos=self.eos,
  402. vocab_size=len(token_list),
  403. token_list=token_list,
  404. pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
  405. )
  406. # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
  407. # for scorer in scorers.values():
  408. # if isinstance(scorer, torch.nn.Module):
  409. # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
  410. self.beam_search = beam_search
  411. def generate_chunk(self,
  412. speech,
  413. speech_lengths=None,
  414. key: list = None,
  415. tokenizer=None,
  416. frontend=None,
  417. **kwargs,
  418. ):
  419. cache = kwargs.get("cache", {})
  420. speech = speech.to(device=kwargs["device"])
  421. speech_lengths = speech_lengths.to(device=kwargs["device"])
  422. # Encoder
  423. encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache,
  424. is_final=kwargs.get("is_final", False))
  425. if isinstance(encoder_out, tuple):
  426. encoder_out = encoder_out[0]
  427. if "running_hyps" not in cache:
  428. running_hyps = self.beam_search.init_hyp(encoder_out)
  429. cache["running_hyps"] = running_hyps
  430. # predictor
  431. predictor_outs = self.calc_predictor_chunk(encoder_out,
  432. encoder_out_lens,
  433. cache=cache,
  434. is_final=kwargs.get("is_final", False),
  435. )
  436. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
  437. predictor_outs[2], predictor_outs[3]
  438. pre_token_length = pre_token_length.round().long()
  439. if torch.max(pre_token_length) < 1:
  440. return []
  441. maxlen = minlen = pre_token_length
  442. if kwargs.get("is_final", False):
  443. maxlen += kwargs.get("token_num_relax", 5)
  444. minlen = max(0, minlen - kwargs.get("token_num_relax", 5))
  445. # c. Passed the encoder result and the beam search
  446. nbest_hyps = self.beam_search(
  447. x=encoder_out[0], scama_mask=None, pre_acoustic_embeds=pre_acoustic_embeds, maxlen=int(maxlen), minlen=int(minlen), cache=cache,
  448. )
  449. cache["running_hyps"] = nbest_hyps
  450. nbest_hyps = nbest_hyps[: self.nbest]
  451. results = []
  452. for hyp in nbest_hyps:
  453. # assert isinstance(hyp, (Hypothesis)), type(hyp)
  454. # remove sos/eos and get results
  455. last_pos = -1
  456. if isinstance(hyp.yseq, list):
  457. token_int = hyp.yseq[1:last_pos]
  458. else:
  459. token_int = hyp.yseq[1:last_pos].tolist()
  460. # remove blank symbol id, which is assumed to be 0
  461. token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
  462. # Change integer-ids to tokens
  463. token = tokenizer.ids2tokens(token_int)
  464. # text = tokenizer.tokens2text(token)
  465. result_i = token
  466. results.extend(result_i)
  467. return results
  468. def init_cache(self, cache: dict = {}, **kwargs):
  469. device = kwargs.get("device", "cuda")
  470. chunk_size = kwargs.get("chunk_size", [0, 10, 5])
  471. encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
  472. decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
  473. batch_size = 1
  474. enc_output_size = kwargs["encoder_conf"]["output_size"]
  475. feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
  476. cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)).to(device=device),
  477. "cif_alphas": torch.zeros((batch_size, 1)).to(device=device), "chunk_size": chunk_size,
  478. "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
  479. "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)).to(device=device),
  480. "tail_chunk": False}
  481. cache["encoder"] = cache_encoder
  482. cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
  483. "chunk_size": chunk_size}
  484. cache["decoder"] = cache_decoder
  485. cache["frontend"] = {}
  486. cache["prev_samples"] = torch.empty(0).to(device=device)
  487. return cache
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  cache: dict = {},
                  **kwargs,
                  ):
        """Streaming inference over one utterance (batch size must be 1).

        Splits the incoming audio (plus any tail carried over in
        cache["prev_samples"]) into fixed-stride chunks, decodes each chunk with
        generate_chunk, and post-processes the accumulated tokens.

        NOTE(review): the mutable default `cache={}` is shared across calls when
        the caller omits it — presumably callers always pass a cache; confirm.

        Returns:
            (result, meta_data): result is a one-element list of
            {"key": ..., "text": ...}; meta_data carries timing stats.
        """
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        # Fresh session: build the streaming state.
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

        meta_data = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        # chunk_size[1] centre frames x 960 samples/frame -> 600ms stride @16kHz
        chunk_stride_samples = int(chunk_size[1] * 960)  # 600ms

        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"),
                                                        tokenizer=tokenizer,
                                                        cache=cfg,
                                                        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set 1"

        # Prepend the unconsumed tail samples from the previous call.
        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))

        # n full chunks (+ the partial tail chunk when final); m leftover samples.
        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
        tokens = []
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]

            # extract fbank feats
            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend, cache=cache["frontend"],
                                                   is_final=kwargs["is_final"])
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

            tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache,
                                           frontend=frontend, **kwargs)
            tokens.extend(tokens_i)

        text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
        result_i = {"key": key[0], "text": text_postprocessed}
        result = [result_i]

        # Carry the unconsumed tail to the next call.  When m == 0 the slice
        # [:-0] is empty, which is correct: everything was consumed.
        cache["prev_samples"] = audio_sample[:-m]
        if _is_final:
            # Utterance finished: reset the whole streaming state.
            self.init_cache(cache, **kwargs)

        if kwargs.get("output_dir"):
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]
            ibest_writer["token"][key[0]] = " ".join(tokens)
            ibest_writer["text"][key[0]] = text_postprocessed

        return result, meta_data