# paraformer.py
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. import torch.nn as nn
  11. import random
  12. import numpy as np
  13. # from funasr.layers.abs_normalize import AbsNormalize
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, # noqa: H301
  16. )
  17. # from funasr.models.ctc import CTC
  18. # from funasr.models.decoder.abs_decoder import AbsDecoder
  19. # from funasr.models.e2e_asr_common import ErrorCalculator
  20. # from funasr.models.encoder.abs_encoder import AbsEncoder
  21. # from funasr.models.frontend.abs_frontend import AbsFrontend
  22. # from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  23. from funasr.models.predictor.cif import mae_loss
  24. # from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  25. # from funasr.models.specaug.abs_specaug import AbsSpecAug
  26. from funasr.modules.add_sos_eos import add_sos_eos
  27. from funasr.modules.nets_utils import make_pad_mask, pad_list
  28. from funasr.modules.nets_utils import th_accuracy
  29. from funasr.torch_utils.device_funcs import force_gatherable
  30. # from funasr.models.base_model import FunASRModel
  31. # from funasr.models.predictor.cif import CifPredictorV3
  32. from funasr.cli.model_class_factory import *
# torch.cuda.amp.autocast only exists from torch 1.6.0 onwards; on older
# versions install a no-op stand-in so `with autocast(False):` (used in
# Paraformer.encode) works regardless of the installed torch version.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
  40. class Paraformer(nn.Module):
  41. """
  42. Author: Speech Lab of DAMO Academy, Alibaba Group
  43. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  44. https://arxiv.org/abs/2206.08317
  45. """
  46. def __init__(
  47. self,
  48. # token_list: Union[Tuple[str, ...], List[str]],
  49. frontend: Optional[str] = None,
  50. frontend_conf: Optional[Dict] = None,
  51. specaug: Optional[str] = None,
  52. specaug_conf: Optional[Dict] = None,
  53. normalize: str = None,
  54. normalize_conf: Optional[Dict] = None,
  55. encoder: str = None,
  56. encoder_conf: Optional[Dict] = None,
  57. decoder: str = None,
  58. decoder_conf: Optional[Dict] = None,
  59. ctc: str = None,
  60. ctc_conf: Optional[Dict] = None,
  61. predictor: str = None,
  62. predictor_conf: Optional[Dict] = None,
  63. ctc_weight: float = 0.5,
  64. interctc_weight: float = 0.0,
  65. input_size: int = 80,
  66. vocab_size: int = -1,
  67. ignore_id: int = -1,
  68. blank_id: int = 0,
  69. sos: int = 1,
  70. eos: int = 2,
  71. lsm_weight: float = 0.0,
  72. length_normalized_loss: bool = False,
  73. # report_cer: bool = True,
  74. # report_wer: bool = True,
  75. # sym_space: str = "<space>",
  76. # sym_blank: str = "<blank>",
  77. # extract_feats_in_collect_stats: bool = True,
  78. # predictor=None,
  79. predictor_weight: float = 0.0,
  80. predictor_bias: int = 0,
  81. sampling_ratio: float = 0.2,
  82. share_embedding: bool = False,
  83. # preencoder: Optional[AbsPreEncoder] = None,
  84. # postencoder: Optional[AbsPostEncoder] = None,
  85. use_1st_decoder_loss: bool = False,
  86. **kwargs,
  87. ):
  88. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  89. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  90. super().__init__()
  91. # import pdb;
  92. # pdb.set_trace()
  93. if frontend is not None:
  94. frontend_class = frontend_choices.get_class(frontend)
  95. frontend = frontend_class(**frontend_conf)
  96. if specaug is not None:
  97. specaug_class = specaug_choices.get_class(specaug)
  98. specaug = specaug_class(**specaug_conf)
  99. if normalize is not None:
  100. normalize_class = normalize_choices.get_class(normalize)
  101. normalize = normalize_class(**normalize_conf)
  102. encoder_class = encoder_choices.get_class(encoder)
  103. encoder = encoder_class(input_size=input_size, **encoder_conf)
  104. encoder_output_size = encoder.output_size()
  105. if decoder is not None:
  106. decoder_class = decoder_choices.get_class(decoder)
  107. decoder = decoder_class(
  108. vocab_size=vocab_size,
  109. encoder_output_size=encoder_output_size,
  110. **decoder_conf,
  111. )
  112. if ctc_weight > 0.0:
  113. if ctc_conf is None:
  114. ctc_conf = {}
  115. ctc = CTC(
  116. odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
  117. )
  118. if predictor is not None:
  119. predictor_class = predictor_choices.get_class(predictor)
  120. predictor = predictor_class(**predictor_conf)
  121. # note that eos is the same as sos (equivalent ID)
  122. self.blank_id = blank_id
  123. self.sos = sos if sos is not None else vocab_size - 1
  124. self.eos = eos if eos is not None else vocab_size - 1
  125. self.vocab_size = vocab_size
  126. self.ignore_id = ignore_id
  127. self.ctc_weight = ctc_weight
  128. self.interctc_weight = interctc_weight
  129. # self.token_list = token_list.copy()
  130. #
  131. self.frontend = frontend
  132. self.specaug = specaug
  133. self.normalize = normalize
  134. # self.preencoder = preencoder
  135. # self.postencoder = postencoder
  136. self.encoder = encoder
  137. #
  138. # if not hasattr(self.encoder, "interctc_use_conditioning"):
  139. # self.encoder.interctc_use_conditioning = False
  140. # if self.encoder.interctc_use_conditioning:
  141. # self.encoder.conditioning_layer = torch.nn.Linear(
  142. # vocab_size, self.encoder.output_size()
  143. # )
  144. #
  145. # self.error_calculator = None
  146. #
  147. if ctc_weight == 1.0:
  148. self.decoder = None
  149. else:
  150. self.decoder = decoder
  151. self.criterion_att = LabelSmoothingLoss(
  152. size=vocab_size,
  153. padding_idx=ignore_id,
  154. smoothing=lsm_weight,
  155. normalize_length=length_normalized_loss,
  156. )
  157. #
  158. # if report_cer or report_wer:
  159. # self.error_calculator = ErrorCalculator(
  160. # token_list, sym_space, sym_blank, report_cer, report_wer
  161. # )
  162. #
  163. if ctc_weight == 0.0:
  164. self.ctc = None
  165. else:
  166. self.ctc = ctc
  167. #
  168. # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
  169. self.predictor = predictor
  170. self.predictor_weight = predictor_weight
  171. self.predictor_bias = predictor_bias
  172. self.sampling_ratio = sampling_ratio
  173. self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
  174. # self.step_cur = 0
  175. #
  176. self.share_embedding = share_embedding
  177. if self.share_embedding:
  178. self.decoder.embed = None
  179. self.use_1st_decoder_loss = use_1st_decoder_loss
  180. def forward(
  181. self,
  182. speech: torch.Tensor,
  183. speech_lengths: torch.Tensor,
  184. text: torch.Tensor,
  185. text_lengths: torch.Tensor,
  186. **kwargs,
  187. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  188. """Frontend + Encoder + Decoder + Calc loss
  189. Args:
  190. speech: (Batch, Length, ...)
  191. speech_lengths: (Batch, )
  192. text: (Batch, Length)
  193. text_lengths: (Batch,)
  194. decoding_ind: int
  195. """
  196. decoding_ind = kwargs.get("kwargs", None)
  197. # import pdb;
  198. # pdb.set_trace()
  199. if len(text_lengths.size()) > 1:
  200. text_lengths = text_lengths[:, 0]
  201. if len(speech_lengths.size()) > 1:
  202. speech_lengths = speech_lengths[:, 0]
  203. batch_size = speech.shape[0]
  204. # # for data-parallel
  205. # text = text[:, : text_lengths.max()]
  206. # speech = speech[:, :speech_lengths.max()]
  207. # 1. Encoder
  208. if hasattr(self.encoder, "overlap_chunk_cls"):
  209. ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
  210. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
  211. else:
  212. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  213. intermediate_outs = None
  214. if isinstance(encoder_out, tuple):
  215. intermediate_outs = encoder_out[1]
  216. encoder_out = encoder_out[0]
  217. loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
  218. loss_ctc, cer_ctc = None, None
  219. loss_pre = None
  220. stats = dict()
  221. # 1. CTC branch
  222. if self.ctc_weight != 0.0:
  223. loss_ctc, cer_ctc = self._calc_ctc_loss(
  224. encoder_out, encoder_out_lens, text, text_lengths
  225. )
  226. # Collect CTC branch stats
  227. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  228. stats["cer_ctc"] = cer_ctc
  229. # Intermediate CTC (optional)
  230. loss_interctc = 0.0
  231. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  232. for layer_idx, intermediate_out in intermediate_outs:
  233. # we assume intermediate_out has the same length & padding
  234. # as those of encoder_out
  235. loss_ic, cer_ic = self._calc_ctc_loss(
  236. intermediate_out, encoder_out_lens, text, text_lengths
  237. )
  238. loss_interctc = loss_interctc + loss_ic
  239. # Collect Intermedaite CTC stats
  240. stats["loss_interctc_layer{}".format(layer_idx)] = (
  241. loss_ic.detach() if loss_ic is not None else None
  242. )
  243. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  244. loss_interctc = loss_interctc / len(intermediate_outs)
  245. # calculate whole encoder loss
  246. loss_ctc = (
  247. 1 - self.interctc_weight
  248. ) * loss_ctc + self.interctc_weight * loss_interctc
  249. # 2b. Attention decoder branch
  250. if self.ctc_weight != 1.0:
  251. loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
  252. encoder_out, encoder_out_lens, text, text_lengths
  253. )
  254. # 3. CTC-Att loss definition
  255. if self.ctc_weight == 0.0:
  256. loss = loss_att + loss_pre * self.predictor_weight
  257. elif self.ctc_weight == 1.0:
  258. loss = loss_ctc
  259. else:
  260. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
  261. if self.use_1st_decoder_loss and pre_loss_att is not None:
  262. loss = loss + (1 - self.ctc_weight) * pre_loss_att
  263. # Collect Attn branch stats
  264. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  265. stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
  266. stats["acc"] = acc_att
  267. stats["cer"] = cer_att
  268. stats["wer"] = wer_att
  269. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  270. stats["loss"] = torch.clone(loss.detach())
  271. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  272. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  273. return loss, stats, weight
  274. def collect_feats(
  275. self,
  276. speech: torch.Tensor,
  277. speech_lengths: torch.Tensor,
  278. text: torch.Tensor,
  279. text_lengths: torch.Tensor,
  280. ) -> Dict[str, torch.Tensor]:
  281. if self.extract_feats_in_collect_stats:
  282. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  283. else:
  284. # Generate dummy stats if extract_feats_in_collect_stats is False
  285. logging.warning(
  286. "Generating dummy stats for feats and feats_lengths, "
  287. "because encoder_conf.extract_feats_in_collect_stats is "
  288. f"{self.extract_feats_in_collect_stats}"
  289. )
  290. feats, feats_lengths = speech, speech_lengths
  291. return {"feats": feats, "feats_lengths": feats_lengths}
  292. def encode(
  293. self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
  294. ) -> Tuple[torch.Tensor, torch.Tensor]:
  295. """Frontend + Encoder. Note that this method is used by asr_inference.py
  296. Args:
  297. speech: (Batch, Length, ...)
  298. speech_lengths: (Batch, )
  299. ind: int
  300. """
  301. with autocast(False):
  302. # # 1. Extract feats
  303. # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  304. # 2. Data augmentation
  305. if self.specaug is not None and self.training:
  306. feats, feats_lengths = self.specaug(speech, speech_lengths)
  307. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  308. if self.normalize is not None:
  309. feats, feats_lengths = self.normalize(feats, feats_lengths)
  310. # # Pre-encoder, e.g. used for raw input data
  311. # if self.preencoder is not None:
  312. # feats, feats_lengths = self.preencoder(feats, feats_lengths)
  313. # 4. Forward encoder
  314. # feats: (Batch, Length, Dim)
  315. # -> encoder_out: (Batch, Length2, Dim2)
  316. if self.encoder.interctc_use_conditioning:
  317. if hasattr(self.encoder, "overlap_chunk_cls"):
  318. encoder_out, encoder_out_lens, _ = self.encoder(
  319. feats, feats_lengths, ctc=self.ctc, ind=ind
  320. )
  321. encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
  322. encoder_out_lens,
  323. chunk_outs=None)
  324. else:
  325. encoder_out, encoder_out_lens, _ = self.encoder(
  326. feats, feats_lengths, ctc=self.ctc
  327. )
  328. else:
  329. if hasattr(self.encoder, "overlap_chunk_cls"):
  330. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
  331. encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
  332. encoder_out_lens,
  333. chunk_outs=None)
  334. else:
  335. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
  336. intermediate_outs = None
  337. if isinstance(encoder_out, tuple):
  338. intermediate_outs = encoder_out[1]
  339. encoder_out = encoder_out[0]
  340. # # Post-encoder, e.g. NLU
  341. # if self.postencoder is not None:
  342. # encoder_out, encoder_out_lens = self.postencoder(
  343. # encoder_out, encoder_out_lens
  344. # )
  345. assert encoder_out.size(0) == speech.size(0), (
  346. encoder_out.size(),
  347. speech.size(0),
  348. )
  349. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  350. encoder_out.size(),
  351. encoder_out_lens.max(),
  352. )
  353. if intermediate_outs is not None:
  354. return (encoder_out, intermediate_outs), encoder_out_lens
  355. return encoder_out, encoder_out_lens
  356. def calc_predictor(self, encoder_out, encoder_out_lens):
  357. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  358. encoder_out.device)
  359. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
  360. ignore_id=self.ignore_id)
  361. return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
  362. def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
  363. decoder_outs = self.decoder(
  364. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
  365. )
  366. decoder_out = decoder_outs[0]
  367. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  368. return decoder_out, ys_pad_lens
  369. def _extract_feats(
  370. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  371. ) -> Tuple[torch.Tensor, torch.Tensor]:
  372. assert speech_lengths.dim() == 1, speech_lengths.shape
  373. # for data-parallel
  374. speech = speech[:, : speech_lengths.max()]
  375. if self.frontend is not None:
  376. # Frontend
  377. # e.g. STFT and Feature extract
  378. # data_loader may send time-domain signal in this case
  379. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  380. feats, feats_lengths = self.frontend(speech, speech_lengths)
  381. else:
  382. # No frontend and no feature extract
  383. feats, feats_lengths = speech, speech_lengths
  384. return feats, feats_lengths
    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.
        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
        Returns:
            nll: (Batch,) per-utterance NLL summed over target positions.
        """
        # Prepend <sos> to the decoder input and append <eos> to the target.
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood; padded positions contribute zero via
        # ignore_index, then the per-token values are summed per utterance.
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll
  419. def batchify_nll(
  420. self,
  421. encoder_out: torch.Tensor,
  422. encoder_out_lens: torch.Tensor,
  423. ys_pad: torch.Tensor,
  424. ys_pad_lens: torch.Tensor,
  425. batch_size: int = 100,
  426. ):
  427. """Compute negative log likelihood(nll) from transformer-decoder
  428. To avoid OOM, this fuction seperate the input into batches.
  429. Then call nll for each batch and combine and return results.
  430. Args:
  431. encoder_out: (Batch, Length, Dim)
  432. encoder_out_lens: (Batch,)
  433. ys_pad: (Batch, Length)
  434. ys_pad_lens: (Batch,)
  435. batch_size: int, samples each batch contain when computing nll,
  436. you may change this to avoid OOM or increase
  437. GPU memory usage
  438. """
  439. total_num = encoder_out.size(0)
  440. if total_num <= batch_size:
  441. nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  442. else:
  443. nll = []
  444. start_idx = 0
  445. while True:
  446. end_idx = min(start_idx + batch_size, total_num)
  447. batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
  448. batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
  449. batch_ys_pad = ys_pad[start_idx:end_idx, :]
  450. batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
  451. batch_nll = self.nll(
  452. batch_encoder_out,
  453. batch_encoder_out_lens,
  454. batch_ys_pad,
  455. batch_ys_pad_lens,
  456. )
  457. nll.append(batch_nll)
  458. start_idx = end_idx
  459. if start_idx == total_num:
  460. break
  461. nll = torch.cat(nll)
  462. assert nll.size(0) == total_num
  463. return nll
  464. def _calc_att_loss(
  465. self,
  466. encoder_out: torch.Tensor,
  467. encoder_out_lens: torch.Tensor,
  468. ys_pad: torch.Tensor,
  469. ys_pad_lens: torch.Tensor,
  470. ):
  471. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  472. encoder_out.device)
  473. if self.predictor_bias == 1:
  474. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  475. ys_pad_lens = ys_pad_lens + self.predictor_bias
  476. pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
  477. ignore_id=self.ignore_id)
  478. # 0. sampler
  479. decoder_out_1st = None
  480. pre_loss_att = None
  481. if self.sampling_ratio > 0.0:
  482. if self.use_1st_decoder_loss:
  483. sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
  484. pre_acoustic_embeds)
  485. else:
  486. sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
  487. pre_acoustic_embeds)
  488. else:
  489. if self.step_cur < 2:
  490. logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  491. sematic_embeds = pre_acoustic_embeds
  492. # 1. Forward decoder
  493. decoder_outs = self.decoder(
  494. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
  495. )
  496. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  497. if decoder_out_1st is None:
  498. decoder_out_1st = decoder_out
  499. # 2. Compute attention loss
  500. loss_att = self.criterion_att(decoder_out, ys_pad)
  501. acc_att = th_accuracy(
  502. decoder_out_1st.view(-1, self.vocab_size),
  503. ys_pad,
  504. ignore_label=self.ignore_id,
  505. )
  506. loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  507. # Compute cer/wer using attention-decoder
  508. if self.training or self.error_calculator is None:
  509. cer_att, wer_att = None, None
  510. else:
  511. ys_hat = decoder_out_1st.argmax(dim=-1)
  512. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  513. return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Mix predictor embeddings with target embeddings (no-grad 1st pass).

        Runs a first decoder pass from the predictor's acoustic embeddings
        (under no_grad), counts how many tokens it got wrong per utterance,
        and replaces a sampling_ratio-proportional number of random valid
        positions with ground-truth token embeddings (glancing-style
        sampling — presumably following the Paraformer paper; confirm).
        Returns (sematic_embeds, decoder_out_1st), both zeroed at padding.
        """
        # (B, T, 1) mask of valid target positions.
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        # Zero out ignore_id padding so the embedding lookup index is valid.
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            # Reuse the output projection weights as the embedding table.
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            # First-pass decode from pure acoustic embeddings.
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            # input_mask[b, t] == 1 -> keep the acoustic embedding at (b, t).
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                # Number of positions to replace: proportional to how many
                # tokens the first pass got wrong in this utterance.
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    # Randomly choose target_num valid positions to replace.
                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        # Keep acoustic embeddings where input_mask is True; elsewhere take
        # the ground-truth token embedding.
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Same sampling as :meth:`sampler`, but with gradients and 1st-pass loss.

        Unlike sampler(), the first decoder pass is NOT wrapped in no_grad,
        and its attention loss (pre_loss_att) is computed so it can be added
        to the total loss when use_1st_decoder_loss is enabled.
        Returns (sematic_embeds, decoder_out_1st, pre_loss_att).
        """
        # (B, T, 1) mask of valid target positions.
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        # Zero out ignore_id padding so the embedding lookup index is valid.
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            # Reuse the output projection weights as the embedding table.
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        # First-pass decode from pure acoustic embeddings (grad enabled).
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        # Attention loss of the first pass (returned for the total loss).
        pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        # input_mask[b, t] == 1 -> keep the acoustic embedding at (b, t).
        input_mask = torch.ones_like(nonpad_positions)
        bsz, seq_len = ys_pad.size()
        for li in range(bsz):
            # Number of positions to replace: proportional to how many
            # tokens the first pass got wrong in this utterance.
            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
            if target_num > 0:
                # Randomly choose target_num valid positions to replace.
                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
        input_mask = input_mask.eq(1)
        input_mask = input_mask.masked_fill(~nonpad_positions, False)
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        # Keep acoustic embeddings where input_mask is True; elsewhere take
        # the ground-truth token embedding.
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
  570. def _calc_ctc_loss(
  571. self,
  572. encoder_out: torch.Tensor,
  573. encoder_out_lens: torch.Tensor,
  574. ys_pad: torch.Tensor,
  575. ys_pad_lens: torch.Tensor,
  576. ):
  577. # Calc CTC loss
  578. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  579. # Calc CER using CTC
  580. cer_ctc = None
  581. if not self.training and self.error_calculator is not None:
  582. ys_hat = self.ctc.argmax(encoder_out).data
  583. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  584. return loss_ctc, cer_ctc