model.py

import os
import logging
from typing import Union, Dict, List, Tuple, Optional
import torch
import torch.nn as nn
import time

from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.paraformer.search import Hypothesis
from torch.cuda.amp import autocast
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
from funasr.models.ctc.ctc import CTC


@tables.register("model_classes", "Paraformer")
class Paraformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(
        self,
        # token_list: Union[Tuple[str, ...], List[str]],
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        predictor: str = None,
        predictor_conf: Optional[Dict] = None,
        ctc_weight: float = 0.5,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        # report_cer: bool = True,
        # report_wer: bool = True,
        # sym_space: str = "<space>",
        # sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        # predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        use_1st_decoder_loss: bool = False,
        **kwargs,
    ):
        super().__init__()

        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()

        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )
        if predictor is not None:
            predictor_class = tables.predictor_classes.get(predictor)
            predictor = predictor_class(**predictor_conf)

        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        # self.token_list = token_list.copy()
        #
        # self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        # self.preencoder = preencoder
        # self.postencoder = postencoder
        self.encoder = encoder
        #
        # if not hasattr(self.encoder, "interctc_use_conditioning"):
        #     self.encoder.interctc_use_conditioning = False
        # if self.encoder.interctc_use_conditioning:
        #     self.encoder.conditioning_layer = torch.nn.Linear(
        #         vocab_size, self.encoder.output_size()
        #     )
        #
        # referenced by _calc_att_loss/_calc_ctc_loss at eval time; no external
        # error calculator is wired up here, so CER/WER stay None
        self.error_calculator = None
        #
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        #
        # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        # self.step_cur = 0
        #
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None

        self.use_1st_decoder_loss = use_1st_decoder_loss
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        # import pdb;
        # pdb.set_trace()
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]

        batch_size = speech.shape[0]

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
            )

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)

            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)

        # Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        return encoder_out, encoder_out_lens

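    # Run the predictor over the encoder output alone (no targets), as done at
    # inference time: returns the token-level acoustic embeddings, the predicted
    # token count per utterance, the CIF alphas and the fired peak indices.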
    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

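    # Decode the predictor's acoustic embeddings in one parallel pass and return
    # per-token log-probabilities over the vocabulary.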
    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # 0. sampler
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
            )
        else:
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att

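    # Glancing-style sampler used during training: a first decoding pass (no
    # gradients) is compared against the targets, and a number of positions
    # proportional to the decoding errors (scaled by sampling_ratio) is replaced
    # with target token embeddings before the loss-bearing decoder pass above.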
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(
                        dim=0,
                        index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
                        value=0,
                    )
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0
        )
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def init_beam_search(
        self,
        **kwargs,
    ):
        from funasr.models.paraformer.search import BeamSearchPara
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus

        # 1. Build ASR model
        scorers = {}

        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )

        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram

        weights = dict(
            # default to 0.0 so a missing decoding_ctc_weight does not break the subtraction
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearchPara(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
        #     if isinstance(scorer, torch.nn.Module):
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search

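    # Inference entry point: load/compute fbank features, encode, run the
    # predictor and the parallel decoder, optionally rescore with beam search,
    # then strip sos/eos/blank ids and map token ids back to text.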
    def generate(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beam search
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)

        meta_data = {}
        if isinstance(data_in, torch.Tensor):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # wrap the scalar length in a tensor so the later .to(device) call works
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )

        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])

        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]

        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        for i in range(b):
            x = encoder_out[i, : encoder_out_lens[i], :]
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor(
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if ibest_writer is None and kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx+1}best_recog"]

                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()

                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)
                )

                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text = tokenizer.tokens2text(token)

                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    result_i = {"key": key[i], "text": text_postprocessed}

                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        # ibest_writer["text"][key[i]] = text
                        ibest_writer["text"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)

        return results, meta_data
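

# Illustrative usage sketch: in practice this class is driven through FunASR's
# AutoModel wrapper, which builds the frontend, tokenizer and this Paraformer
# model from a packaged config and then calls generate(). The model id
# "paraformer-zh" and the wav path below are example values only, not anything
# this module requires.
if __name__ == "__main__":
    from funasr import AutoModel

    # load a pretrained Paraformer checkpoint by model id (assumed available)
    model = AutoModel(model="paraformer-zh")
    # run recognition on a single audio file; returns a list of result dicts
    # shaped like {"key": ..., "text": ...}
    res = model.generate(input="example.wav")
    print(res)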