model.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import copy
import time
import torch
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict, List, Optional, Tuple

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.models.paraformer.search import Hypothesis
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "BiCifParaformer")
class BiCifParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paper1: FunASR: A Fundamental End-to-End Speech Recognition Toolkit
    https://arxiv.org/abs/2305.11013
    Paper2: Achieving timestamp prediction while recognizing with non-autoregressive end-to-end ASR model
    https://arxiv.org/abs/2301.12343
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

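    # The BiCif predictor returns a second token-count estimate (pre_token_length2)
    # alongside the usual CIF outputs; _calc_pre2_loss below regresses it towards
    # the target token count with the same criterion_pre used for the first estimate.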
    def _calc_pre2_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        _, _, _, _, pre_token_length2 = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)

        return loss_pre2

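    # Same attention-branch loss as the parent Paraformer: CIF acoustic embeddings
    # (optionally mixed with sampled target embeddings when sampling_ratio > 0) are
    # fed to the non-autoregressive decoder, plus a length loss for the first
    # CIF token-count estimate.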
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
            )
        else:
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

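    # calc_predictor_timestamp queries the predictor's upsampled CIF pass
    # (get_upsample_timestamp) so that frame-level alphas and peaks are available
    # for timestamp prediction at inference time.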
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        loss_pre2 = self._calc_pre2_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

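        # For reference, the combined objective assembled below is:
        #     loss = ctc_weight * loss_ctc
        #            + (1 - ctc_weight) * loss_att
        #            + predictor_weight * loss_pre
        #            + 0.5 * predictor_weight * loss_pre2
        # with the CTC term dropped when ctc_weight == 0.0.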
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
                + loss_pre2 * self.predictor_weight * 0.5
            )

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss_pre2"] = loss_pre2.detach().cpu()

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

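    # Inference pipeline: load audio -> fbank frontend -> encoder -> CIF predictor
    # (token counts + acoustic embeddings) -> non-autoregressive decoder (optionally
    # rescored by beam search / LM) -> upsampled CIF pass for token-level timestamps
    # -> text postprocessing.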
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}

        # if isinstance(data_in, torch.Tensor):  # fbank
        #     speech, speech_lengths = data_in, data_lengths
        #     if len(speech.shape) < 3:
        #         speech = speech[None, :, :]
        #     if speech_lengths is None:
        #         speech_lengths = speech.shape[1]
        # else:
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        # BiCifParaformer, test no bias cif2
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(
            encoder_out, encoder_out_lens, pre_token_length
        )

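        # us_alphas / us_peaks are the upsampled CIF activations; they are sliced to
        # encoder_out_lens[i] * 3 below (three upsampled frames per encoder frame)
        # before being handed to ts_prediction_lfr6_standard.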
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)

                    _, timestamp = ts_prediction_lfr6_standard(
                        us_alphas[i][:encoder_out_lens[i] * 3],
                        us_peaks[i][:encoder_out_lens[i] * 3],
                        copy.copy(token),
                        vad_offset=kwargs.get("begin_time", 0),
                    )
                    text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                        token, timestamp
                    )

                    result_i = {
                        "key": key[i],
                        "text": text_postprocessed,
                        "timestamp": time_stamp_postprocessed,
                    }

                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        # ibest_writer["text"][key[i]] = text
                        ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                        ibest_writer["text"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)

        return results, meta_data
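

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): in FunASR this model is normally
# driven through funasr.AutoModel rather than instantiated directly. The model
# id below is a placeholder, not a verified hub name; substitute a
# BiCifParaformer timestamp model from the FunASR model zoo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from funasr import AutoModel

    model = AutoModel(model="<bicif-paraformer-timestamp-model-id>")  # placeholder id
    res = model.generate(input="example.wav")  # text plus token-level timestamps
    print(res)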