model.py

import logging
from typing import Dict, List, Optional, Tuple, Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time

from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.paraformer.search import Hypothesis
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.models.paraformer.model import Paraformer


@tables.register("model_classes", "BiCifParaformer")
class BiCifParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
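
    # BiCif specifics, as implemented below: relative to the base Paraformer,
    # this subclass (1) trains a second CIF token-count estimate (`loss_pre2`,
    # see `_calc_pre2_loss`) in addition to the standard predictor length loss,
    # and (2) exposes upsampled CIF weights/peaks (`calc_predictor_timestamp`)
    # that `generate` turns into token-level timestamps via
    # `ts_prediction_lfr6_standard`.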

    def _calc_pre2_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
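        """Length loss on the predictor's second token-count estimate.

        Only the fifth predictor output (`pre_token_length2`) is used here;
        the first-pass length loss is computed in `_calc_att_loss`.
        """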
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
        return loss_pre2

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
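        """Run the CIF predictor and decoder, then compute the attention loss,
        token accuracy, optional CER/WER, and the first-pass length loss."""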
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            semantic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
            )
        else:
            semantic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(encoder_out, encoder_out_lens, semantic_embeds, ys_pad_lens)
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def calc_predictor(self, encoder_out, encoder_out_lens):
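        """Run the CIF predictor without target text (inference mode) and return
        acoustic embeddings, predicted token count, alphas, and peak indices."""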
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
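        """Return downsampled/upsampled CIF alphas and peaks for timestamping.

        The upsampled outputs are at a higher frame rate than the encoder
        output (the `* 3` slicing in `generate` assumes 3x the encoder rate).
        """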
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch,)
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens, text, text_lengths)
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        loss_pre2 = self._calc_pre2_loss(encoder_out, encoder_out_lens, text, text_lengths)

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
                + loss_pre2 * self.predictor_weight * 0.5
            )

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss_pre2"] = loss_pre2.detach().cpu()
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + self.predictor_bias).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
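
    # Objective composed in forward() above:
    #   loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att
    #          + predictor_weight * loss_pre + 0.5 * predictor_weight * loss_pre2
    # (the CTC term is dropped when ctc_weight == 0).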

    def generate(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
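        """Decode audio (or precomputed fbank tensors) and return `(results, meta_data)`.

        With a tokenizer, each result dict carries "key", "text", and
        "timestamp"; otherwise "key" and "token_int".
        """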
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # keep lengths as a tensor so the .to(device) call below works
                speech_lengths = torch.tensor([speech.shape[1]], device=speech.device)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        # .to() is not in-place: keep the returned tensors
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            # keep the (results, meta_data) return shape on early exit
            return [], meta_data
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        # BiCifParaformer, test no bias cif2
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, pre_token_length)

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # wrap with sos/eos so the stripping logic below stays uniform
                yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]

            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove sos/eos and the blank symbol id
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))

                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)
                    _, timestamp = ts_prediction_lfr6_standard(
                        us_alphas[i][:encoder_out_lens[i] * 3],
                        us_peaks[i][:encoder_out_lens[i] * 3],
                        copy.copy(token),
                        vad_offset=kwargs.get("begin_time", 0),
                    )
                    text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                        token, timestamp
                    )
                    result_i = {
                        "key": key[i],
                        "text": text_postprocessed,
                        "timestamp": time_stamp_postprocessed,
                    }
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        # ibest_writer["text"][key[i]] = text
                        ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                        ibest_writer["text"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)

        return results, meta_data
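

# Usage sketch: driving this model through FunASR's AutoModel wrapper.
# The model id below is an assumption for illustration; substitute any
# BiCifParaformer-based (timestamp-capable) checkpoint you have available.
#
#   from funasr import AutoModel
#
#   model = AutoModel(model="paraformer-zh")  # assumed alias of a timestamp-capable checkpoint
#   res = model.generate(input="example.wav")
#   print(res[0]["text"], res[0].get("timestamp"))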