e2e_uni_asr.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. from funasr.models.e2e_asr_common import ErrorCalculator
  11. from funasr.modules.nets_utils import th_accuracy
  12. from funasr.modules.add_sos_eos import add_sos_eos
  13. from funasr.losses.label_smoothing_loss import (
  14. LabelSmoothingLoss, # noqa: H301
  15. )
  16. from funasr.models.ctc import CTC
  17. from funasr.models.decoder.abs_decoder import AbsDecoder
  18. from funasr.models.encoder.abs_encoder import AbsEncoder
  19. from funasr.models.frontend.abs_frontend import AbsFrontend
  20. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  21. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  22. from funasr.models.specaug.abs_specaug import AbsSpecAug
  23. from funasr.layers.abs_normalize import AbsNormalize
  24. from funasr.torch_utils.device_funcs import force_gatherable
  25. from funasr.models.base_model import FunASRModel
  26. from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
  27. from funasr.models.predictor.cif import mae_loss
  28. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  29. from torch.cuda.amp import autocast
  30. else:
  31. # Nothing to do if torch<1.6.0
  32. @contextmanager
  33. def autocast(enabled=True):
  34. yield
  35. class UniASR(FunASRModel):
  36. """
  37. Author: Speech Lab of DAMO Academy, Alibaba Group
  38. """
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        decoder_attention_chunk_type: str = 'chunk',
        encoder2: AbsEncoder = None,
        decoder2: AbsDecoder = None,
        ctc2: CTC = None,
        ctc_weight2: float = 0.5,
        interctc_weight2: float = 0.0,
        predictor2=None,
        predictor_weight2: float = 0.0,
        decoder_attention_chunk_type2: str = 'chunk',
        stride_conv=None,
        loss_weight_model1: float = 0.5,
        enable_maas_finetune: bool = False,
        freeze_encoder2: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
        encoder1_encoder2_joint_training: bool = True,
    ):
        """Build a two-pass UniASR model.

        The model holds two encoder/decoder/CTC/predictor stacks:
        "model1" (``encoder``/``decoder``/``ctc``/``predictor``) and
        "model2" (``encoder2``/``decoder2``/``ctc2``/``predictor2``).
        The total loss in ``forward`` mixes the two with
        ``loss_weight_model1``.

        Args:
            vocab_size: Output vocabulary size (CTC/attention softmax dim).
            token_list: Token inventory; copied so later mutation by the
                caller cannot affect the model.
            frontend/specaug/normalize: Optional feature pipeline modules.
            encoder, decoder, ctc: First-pass modules.
            ctc_weight: CTC/attention interpolation for model1 (1.0 disables
                the decoder, 0.0 disables CTC).
            interctc_weight: Weight of intermediate-layer CTC for model1.
            ignore_id: Padding label id ignored by the losses.
            lsm_weight: Label-smoothing weight for the attention loss.
            length_normalized_loss: Normalize losses by target length; also
                switches ``forward``'s gather weight to token counts.
            predictor, predictor_weight: CIF-style length predictor for
                model1 and its loss weight.
            decoder_attention_chunk_type: Cross-attention masking mode used
                for streaming (SCAMA) decoding of model1.
            encoder2/decoder2/ctc2/...: Second-pass counterparts.
            stride_conv: Optional subsampling conv applied between passes.
            loss_weight_model1: Mixing weight between model1 and model2 loss.
            enable_maas_finetune: If True, model1 runs under ``no_grad`` in
                ``forward`` (only model2 is trained).
            freeze_encoder2: If True, ``encode2`` runs under ``no_grad``.
            encoder1_encoder2_joint_training: If False, gradients are cut
                between the two passes.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__()
        # NOTE(review): blank/sos/eos ids are hard-coded (0/1/2) — the token
        # list is assumed to be laid out accordingly.
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        # Older encoders may not define the flag; default it off.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            # Projection feeding intermediate CTC posteriors back into the
            # encoder (self-conditioned CTC).
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.error_calculator = None
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        # MAE loss between predicted and true target lengths (CIF predictor).
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0
        if self.encoder.overlap_chunk_cls is not None:
            # SCAMA mask builder for streaming cross-attention (model1).
            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = decoder_attention_chunk_type
        self.encoder2 = encoder2
        self.decoder2 = decoder2
        self.ctc_weight2 = ctc_weight2
        if ctc_weight2 == 0.0:
            self.ctc2 = None
        else:
            self.ctc2 = ctc2
        self.interctc_weight2 = interctc_weight2
        self.predictor2 = predictor2
        self.predictor_weight2 = predictor_weight2
        self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
        self.stride_conv = stride_conv
        self.loss_weight_model1 = loss_weight_model1
        # NOTE(review): encoder2 is dereferenced unconditionally here even
        # though its default is None — a second encoder is effectively
        # required; confirm against the config entry points.
        if self.encoder2.overlap_chunk_cls is not None:
            # SCAMA mask builder for streaming cross-attention (model2).
            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
        self.enable_maas_finetune = enable_maas_finetune
        self.freeze_encoder2 = freeze_encoder2
        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
        self.length_normalized_loss = length_normalized_loss
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: int = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Runs both passes: model1 (encoder/decoder/ctc) on the raw features,
        then model2 (encoder2/decoder2/ctc2) on model1's output, and mixes
        their losses with ``self.loss_weight_model1``.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            decoding_ind: chunk-configuration index forwarded to the
                encoder's overlap-chunk selector (random during training
                when None).

        Returns:
            (loss, stats, weight) made DataParallel-safe by
            ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel: trim padding beyond the longest item
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        # Pick the chunk configuration for this step (random in training).
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        # 1. Encoder (model1).  Under MaaS finetune only model2 is trained,
        # so the first pass is computed without gradients.
        if self.enable_maas_finetune:
            with torch.no_grad():
                speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        loss_pre = None
        loss, loss1, loss2 = 0.0, 0.0, 0.0
        if self.loss_weight_model1 > 0.0:
            ## model1
            # The two branches below are identical except that the first is
            # wrapped in no_grad for MaaS finetuning.
            # 1. CTC branch
            if self.enable_maas_finetune:
                with torch.no_grad():
                    if self.ctc_weight != 0.0:
                        # CTC operates on the de-chunked encoder output.
                        if self.encoder.overlap_chunk_cls is not None:
                            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                                encoder_out,
                                encoder_out_lens,
                                chunk_outs=None)
                        loss_ctc, cer_ctc = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        # Collect CTC branch stats
                        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                        stats["cer_ctc"] = cer_ctc
                    # Intermediate CTC (optional)
                    loss_interctc = 0.0
                    if self.interctc_weight != 0.0 and intermediate_outs is not None:
                        for layer_idx, intermediate_out in intermediate_outs:
                            # we assume intermediate_out has the same length & padding
                            # as those of encoder_out
                            if self.encoder.overlap_chunk_cls is not None:
                                encoder_out_ctc, encoder_out_lens_ctc = \
                                    self.encoder.overlap_chunk_cls.remove_chunk(
                                        intermediate_out,
                                        encoder_out_lens,
                                        chunk_outs=None)
                            loss_ic, cer_ic = self._calc_ctc_loss(
                                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                            )
                            loss_interctc = loss_interctc + loss_ic
                            # Collect Intermedaite CTC stats
                            stats["loss_interctc_layer{}".format(layer_idx)] = (
                                loss_ic.detach() if loss_ic is not None else None
                            )
                            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                        loss_interctc = loss_interctc / len(intermediate_outs)
                        # calculate whole encoder loss
                        loss_ctc = (
                            1 - self.interctc_weight
                        ) * loss_ctc + self.interctc_weight * loss_interctc
                    # 2b. Attention decoder branch (with CIF predictor)
                    if self.ctc_weight != 1.0:
                        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                            encoder_out, encoder_out_lens, text, text_lengths
                        )
                    # 3. CTC-Att loss definition
                    if self.ctc_weight == 0.0:
                        loss = loss_att + loss_pre * self.predictor_weight
                    elif self.ctc_weight == 1.0:
                        loss = loss_ctc
                    else:
                        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                    # Collect Attn branch stats
                    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                    stats["acc"] = acc_att
                    stats["cer"] = cer_att
                    stats["wer"] = wer_att
                    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            else:
                if self.ctc_weight != 0.0:
                    if self.encoder.overlap_chunk_cls is not None:
                        encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                            encoder_out,
                            encoder_out_lens,
                            chunk_outs=None)
                    loss_ctc, cer_ctc = self._calc_ctc_loss(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                    # Collect CTC branch stats
                    stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                    stats["cer_ctc"] = cer_ctc
                # Intermediate CTC (optional)
                loss_interctc = 0.0
                if self.interctc_weight != 0.0 and intermediate_outs is not None:
                    for layer_idx, intermediate_out in intermediate_outs:
                        # we assume intermediate_out has the same length & padding
                        # as those of encoder_out
                        if self.encoder.overlap_chunk_cls is not None:
                            encoder_out_ctc, encoder_out_lens_ctc = \
                                self.encoder.overlap_chunk_cls.remove_chunk(
                                    intermediate_out,
                                    encoder_out_lens,
                                    chunk_outs=None)
                        loss_ic, cer_ic = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        loss_interctc = loss_interctc + loss_ic
                        # Collect Intermedaite CTC stats
                        stats["loss_interctc_layer{}".format(layer_idx)] = (
                            loss_ic.detach() if loss_ic is not None else None
                        )
                        stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                    loss_interctc = loss_interctc / len(intermediate_outs)
                    # calculate whole encoder loss
                    loss_ctc = (
                        1 - self.interctc_weight
                    ) * loss_ctc + self.interctc_weight * loss_interctc
                # 2b. Attention decoder branch (with CIF predictor)
                if self.ctc_weight != 1.0:
                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                        encoder_out, encoder_out_lens, text, text_lengths
                    )
                # 3. CTC-Att loss definition
                if self.ctc_weight == 0.0:
                    loss = loss_att + loss_pre * self.predictor_weight
                elif self.ctc_weight == 1.0:
                    loss = loss_ctc
                else:
                    loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                # Collect Attn branch stats
                stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                stats["acc"] = acc_att
                stats["cer"] = cer_att
                stats["wer"] = wer_att
                stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            loss1 = loss
        if self.loss_weight_model1 < 1.0:
            ## model2
            # encoder2: second pass consumes model1's output plus the raw
            # features (residual input inside encode2).
            if self.freeze_encoder2:
                with torch.no_grad():
                    encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
            else:
                encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
            intermediate_outs = None
            if isinstance(encoder_out, tuple):
                intermediate_outs = encoder_out[1]
                encoder_out = encoder_out[0]
            # CTC2
            if self.ctc_weight2 != 0.0:
                if self.encoder2.overlap_chunk_cls is not None:
                    encoder_out_ctc, encoder_out_lens_ctc = \
                        self.encoder2.overlap_chunk_cls.remove_chunk(
                            encoder_out,
                            encoder_out_lens,
                            chunk_outs=None,
                        )
                loss_ctc, cer_ctc = self._calc_ctc_loss2(
                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                )
                # Collect CTC branch stats
                stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None
                stats["cer_ctc2"] = cer_ctc
            # Intermediate CTC (optional)
            loss_interctc = 0.0
            if self.interctc_weight2 != 0.0 and intermediate_outs is not None:
                for layer_idx, intermediate_out in intermediate_outs:
                    # we assume intermediate_out has the same length & padding
                    # as those of encoder_out
                    if self.encoder2.overlap_chunk_cls is not None:
                        encoder_out_ctc, encoder_out_lens_ctc = \
                            self.encoder2.overlap_chunk_cls.remove_chunk(
                                intermediate_out,
                                encoder_out_lens,
                                chunk_outs=None)
                    loss_ic, cer_ic = self._calc_ctc_loss2(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                    loss_interctc = loss_interctc + loss_ic
                    # Collect Intermedaite CTC stats
                    stats["loss_interctc_layer{}2".format(layer_idx)] = (
                        loss_ic.detach() if loss_ic is not None else None
                    )
                    stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic
                loss_interctc = loss_interctc / len(intermediate_outs)
                # calculate whole encoder loss
                loss_ctc = (
                    1 - self.interctc_weight2
                ) * loss_ctc + self.interctc_weight2 * loss_interctc
            # 2b. Attention decoder branch
            if self.ctc_weight2 != 1.0:
                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
                    encoder_out, encoder_out_lens, text, text_lengths
                )
            # 3. CTC-Att loss definition
            if self.ctc_weight2 == 0.0:
                loss = loss_att + loss_pre * self.predictor_weight2
            elif self.ctc_weight2 == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight2 * loss_ctc + (
                    1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2
            # Collect Attn branch stats
            stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
            stats["acc2"] = acc_att
            stats["cer2"] = cer_att
            stats["wer2"] = wer_att
            stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
            loss2 = loss
        # NOTE(review): if either weight is exactly 0.0/1.0 the unused lossN
        # stays a float 0.0 and .detach() below would fail — looks like both
        # passes are expected to be active in practice; confirm.
        loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
        stats["loss1"] = torch.clone(loss1.detach())
        stats["loss2"] = torch.clone(loss2.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            # weight by token count (+1 for the appended eos) instead of batch size
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  396. def collect_feats(
  397. self,
  398. speech: torch.Tensor,
  399. speech_lengths: torch.Tensor,
  400. text: torch.Tensor,
  401. text_lengths: torch.Tensor,
  402. ) -> Dict[str, torch.Tensor]:
  403. if self.extract_feats_in_collect_stats:
  404. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  405. else:
  406. # Generate dummy stats if extract_feats_in_collect_stats is False
  407. logging.warning(
  408. "Generating dummy stats for feats and feats_lengths, "
  409. "because encoder_conf.extract_feats_in_collect_stats is "
  410. f"{self.extract_feats_in_collect_stats}"
  411. )
  412. feats, feats_lengths = speech, speech_lengths
  413. return {"feats": feats, "feats_lengths": feats_lengths}
  414. def encode(
  415. self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
  416. ) -> Tuple[torch.Tensor, torch.Tensor]:
  417. """Frontend + Encoder. Note that this method is used by asr_inference.py
  418. Args:
  419. speech: (Batch, Length, ...)
  420. speech_lengths: (Batch, )
  421. """
  422. with autocast(False):
  423. # 1. Extract feats
  424. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  425. # 2. Data augmentation
  426. if self.specaug is not None and self.training:
  427. feats, feats_lengths = self.specaug(feats, feats_lengths)
  428. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  429. if self.normalize is not None:
  430. feats, feats_lengths = self.normalize(feats, feats_lengths)
  431. speech_raw = feats.clone().to(feats.device)
  432. # Pre-encoder, e.g. used for raw input data
  433. if self.preencoder is not None:
  434. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  435. # 4. Forward encoder
  436. # feats: (Batch, Length, Dim)
  437. # -> encoder_out: (Batch, Length2, Dim2)
  438. if self.encoder.interctc_use_conditioning:
  439. encoder_out, encoder_out_lens, _ = self.encoder(
  440. feats, feats_lengths, ctc=self.ctc, ind=ind
  441. )
  442. else:
  443. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
  444. intermediate_outs = None
  445. if isinstance(encoder_out, tuple):
  446. intermediate_outs = encoder_out[1]
  447. encoder_out = encoder_out[0]
  448. # Post-encoder, e.g. NLU
  449. if self.postencoder is not None:
  450. encoder_out, encoder_out_lens = self.postencoder(
  451. encoder_out, encoder_out_lens
  452. )
  453. assert encoder_out.size(0) == speech.size(0), (
  454. encoder_out.size(),
  455. speech.size(0),
  456. )
  457. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  458. encoder_out.size(),
  459. encoder_out_lens.max(),
  460. )
  461. if intermediate_outs is not None:
  462. return (encoder_out, intermediate_outs), encoder_out_lens
  463. return speech_raw, encoder_out, encoder_out_lens
    def encode2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Second-pass encoder. Note that this method is used by asr_inference.py

        Consumes the first-pass (chunked) encoder output plus the raw
        features saved by ``encode``, concatenates them as a residual input,
        optionally subsamples with ``stride_conv``, and runs ``encoder2``.

        Args:
            encoder_out: first-pass encoder output (still chunked).
            encoder_out_lens: (Batch,) lengths of ``encoder_out``.
            speech: ``speech_raw`` from ``encode`` (normalized features).
            speech_lengths: (Batch, ) original feature lengths.
            ind: chunk-configuration index forwarded to encoder2.

        Returns:
            encoder_out (or ``(encoder_out, intermediate_outs)``) and
            encoder_out_lens of the second pass.
        """
        # The frontend/specaug/normalize pipeline is intentionally skipped
        # here: `speech` already went through it in encode().
        # with autocast(False):
        #     # 1. Extract feats
        #     feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        #
        #     # 2. Data augmentation
        #     if self.specaug is not None and self.training:
        #         feats, feats_lengths = self.specaug(feats, feats_lengths)
        #
        #     # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        #     if self.normalize is not None:
        #         feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        # if self.preencoder is not None:
        #     feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # Undo the overlap-chunking of the first pass so frames align with
        # the raw features.
        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
            chunk_outs=None,
        )
        # residual_input: concatenate raw features with first-pass output
        # along the feature dim.  Assumes both have the same frame count —
        # TODO confirm against the frontends used in configs.
        encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
        encoder_out_lens = encoder_out_lens_rm
        if self.stride_conv is not None:
            speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
        if not self.encoder1_encoder2_joint_training:
            # Cut gradients flowing from model2 back into model1.
            speech = speech.detach()
            speech_lengths = speech_lengths.detach()
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder2.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder2(
                speech, speech_lengths, ctc=self.ctc2, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        # # Post-encoder, e.g. NLU
        # if self.postencoder is not None:
        #     encoder_out, encoder_out_lens = self.postencoder(
        #         encoder_out, encoder_out_lens
        #     )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
  533. def _extract_feats(
  534. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  535. ) -> Tuple[torch.Tensor, torch.Tensor]:
  536. assert speech_lengths.dim() == 1, speech_lengths.shape
  537. # for data-parallel
  538. speech = speech[:, : speech_lengths.max()]
  539. if self.frontend is not None:
  540. # Frontend
  541. # e.g. STFT and Feature extract
  542. # data_loader may send time-domain signal in this case
  543. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  544. feats, feats_lengths = self.frontend(speech, speech_lengths)
  545. else:
  546. # No frontend and no feature extract
  547. feats, feats_lengths = speech, speech_lengths
  548. return feats, feats_lengths
  549. def nll(
  550. self,
  551. encoder_out: torch.Tensor,
  552. encoder_out_lens: torch.Tensor,
  553. ys_pad: torch.Tensor,
  554. ys_pad_lens: torch.Tensor,
  555. ) -> torch.Tensor:
  556. """Compute negative log likelihood(nll) from transformer-decoder
  557. Normally, this function is called in batchify_nll.
  558. Args:
  559. encoder_out: (Batch, Length, Dim)
  560. encoder_out_lens: (Batch,)
  561. ys_pad: (Batch, Length)
  562. ys_pad_lens: (Batch,)
  563. """
  564. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  565. ys_in_lens = ys_pad_lens + 1
  566. # 1. Forward decoder
  567. decoder_out, _ = self.decoder(
  568. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  569. ) # [batch, seqlen, dim]
  570. batch_size = decoder_out.size(0)
  571. decoder_num_class = decoder_out.size(2)
  572. # nll: negative log-likelihood
  573. nll = torch.nn.functional.cross_entropy(
  574. decoder_out.view(-1, decoder_num_class),
  575. ys_out_pad.view(-1),
  576. ignore_index=self.ignore_id,
  577. reduction="none",
  578. )
  579. nll = nll.view(batch_size, -1)
  580. nll = nll.sum(dim=1)
  581. assert nll.size(0) == batch_size
  582. return nll
  583. def batchify_nll(
  584. self,
  585. encoder_out: torch.Tensor,
  586. encoder_out_lens: torch.Tensor,
  587. ys_pad: torch.Tensor,
  588. ys_pad_lens: torch.Tensor,
  589. batch_size: int = 100,
  590. ):
  591. """Compute negative log likelihood(nll) from transformer-decoder
  592. To avoid OOM, this fuction seperate the input into batches.
  593. Then call nll for each batch and combine and return results.
  594. Args:
  595. encoder_out: (Batch, Length, Dim)
  596. encoder_out_lens: (Batch,)
  597. ys_pad: (Batch, Length)
  598. ys_pad_lens: (Batch,)
  599. batch_size: int, samples each batch contain when computing nll,
  600. you may change this to avoid OOM or increase
  601. GPU memory usage
  602. """
  603. total_num = encoder_out.size(0)
  604. if total_num <= batch_size:
  605. nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  606. else:
  607. nll = []
  608. start_idx = 0
  609. while True:
  610. end_idx = min(start_idx + batch_size, total_num)
  611. batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
  612. batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
  613. batch_ys_pad = ys_pad[start_idx:end_idx, :]
  614. batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
  615. batch_nll = self.nll(
  616. batch_encoder_out,
  617. batch_encoder_out_lens,
  618. batch_ys_pad,
  619. batch_ys_pad_lens,
  620. )
  621. nll.append(batch_nll)
  622. start_idx = end_idx
  623. if start_idx == total_num:
  624. break
  625. nll = torch.cat(nll)
  626. assert nll.size(0) == total_num
  627. return nll
  628. def _calc_att_loss(
  629. self,
  630. encoder_out: torch.Tensor,
  631. encoder_out_lens: torch.Tensor,
  632. ys_pad: torch.Tensor,
  633. ys_pad_lens: torch.Tensor,
  634. ):
  635. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  636. ys_in_lens = ys_pad_lens + 1
  637. # 1. Forward decoder
  638. decoder_out, _ = self.decoder(
  639. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  640. )
  641. # 2. Compute attention loss
  642. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  643. acc_att = th_accuracy(
  644. decoder_out.view(-1, self.vocab_size),
  645. ys_out_pad,
  646. ignore_label=self.ignore_id,
  647. )
  648. # Compute cer/wer using attention-decoder
  649. if self.training or self.error_calculator is None:
  650. cer_att, wer_att = None, None
  651. else:
  652. ys_hat = decoder_out.argmax(dim=-1)
  653. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  654. return loss_att, acc_att, cer_att, wer_att
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Model1 attention-decoder loss with the CIF length predictor.

        Computes the label-smoothed attention loss, token accuracy, optional
        CER/WER, and the predictor's length (MAE) loss.  When the encoder is
        chunked and ``decoder_attention_chunk_type == 'chunk'``, a SCAMA mask
        restricts the decoder's cross-attention per predicted alignment.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # (Batch, 1, Length) padding mask over encoder frames.
        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
                                         device=encoder_out.device)[:, None, :]
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            # Mask out the duplicated (look-ahead) frames for the predictor,
            # and zero them in the encoder output before the predictor runs.
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device,
                batch_size=encoder_out.size(0))
            encoder_out = encoder_out * mask_shfit_chunk
        # CIF predictor: acoustic embeddings per target token, predicted
        # target length, and per-frame firing weights (alphas).
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        # Frame-to-token alignments derived from the firing weights.
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Build the streaming (SCAMA) cross-attention mask from the
            # current chunk configuration and the predicted alignments.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            # Non-chunk attention: decode over the de-chunked encoder output.
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens,
                chunk_outs=None)
        # 1. Forward decoder (SCAMA mask and CIF embeddings are extra inputs
        # understood by this project's streaming decoder).
        decoder_out, _ = self.decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # predictor loss: MAE between true and predicted target lengths
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre
  738. def _calc_att_predictor_loss2(
  739. self,
  740. encoder_out: torch.Tensor,
  741. encoder_out_lens: torch.Tensor,
  742. ys_pad: torch.Tensor,
  743. ys_pad_lens: torch.Tensor,
  744. ):
  745. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  746. ys_in_lens = ys_pad_lens + 1
  747. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  748. device=encoder_out.device)[:, None, :]
  749. mask_chunk_predictor = None
  750. if self.encoder2.overlap_chunk_cls is not None:
  751. mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
  752. device=encoder_out.device,
  753. batch_size=encoder_out.size(
  754. 0))
  755. mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  756. batch_size=encoder_out.size(0))
  757. encoder_out = encoder_out * mask_shfit_chunk
  758. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
  759. ys_out_pad,
  760. encoder_out_mask,
  761. ignore_id=self.ignore_id,
  762. mask_chunk_predictor=mask_chunk_predictor,
  763. target_label_length=ys_in_lens,
  764. )
  765. predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
  766. encoder_out_lens)
  767. scama_mask = None
  768. if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
  769. encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
  770. attention_chunk_center_bias = 0
  771. attention_chunk_size = encoder_chunk_size
  772. decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
  773. mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  774. device=encoder_out.device,
  775. batch_size=encoder_out.size(
  776. 0))
  777. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
  778. predictor_alignments=predictor_alignments,
  779. encoder_sequence_length=encoder_out_lens,
  780. chunk_size=1,
  781. encoder_chunk_size=encoder_chunk_size,
  782. attention_chunk_center_bias=attention_chunk_center_bias,
  783. attention_chunk_size=attention_chunk_size,
  784. attention_chunk_type=self.decoder_attention_chunk_type2,
  785. step=None,
  786. predictor_mask_chunk_hopping=mask_chunk_predictor,
  787. decoder_att_look_back_factor=decoder_att_look_back_factor,
  788. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  789. target_length=ys_in_lens,
  790. is_training=self.training,
  791. )
  792. elif self.encoder2.overlap_chunk_cls is not None:
  793. encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  794. chunk_outs=None)
  795. # try:
  796. # 1. Forward decoder
  797. decoder_out, _ = self.decoder2(
  798. encoder_out,
  799. encoder_out_lens,
  800. ys_in_pad,
  801. ys_in_lens,
  802. chunk_mask=scama_mask,
  803. pre_acoustic_embeds=pre_acoustic_embeds,
  804. )
  805. # 2. Compute attention loss
  806. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  807. acc_att = th_accuracy(
  808. decoder_out.view(-1, self.vocab_size),
  809. ys_out_pad,
  810. ignore_label=self.ignore_id,
  811. )
  812. # predictor loss
  813. loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
  814. # Compute cer/wer using attention-decoder
  815. if self.training or self.error_calculator is None:
  816. cer_att, wer_att = None, None
  817. else:
  818. ys_hat = decoder_out.argmax(dim=-1)
  819. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  820. return loss_att, acc_att, cer_att, wer_att, loss_pre
  821. def calc_predictor_mask(
  822. self,
  823. encoder_out: torch.Tensor,
  824. encoder_out_lens: torch.Tensor,
  825. ys_pad: torch.Tensor = None,
  826. ys_pad_lens: torch.Tensor = None,
  827. ):
  828. # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  829. # ys_in_lens = ys_pad_lens + 1
  830. ys_out_pad, ys_in_lens = None, None
  831. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  832. device=encoder_out.device)[:, None, :]
  833. mask_chunk_predictor = None
  834. if self.encoder.overlap_chunk_cls is not None:
  835. mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
  836. device=encoder_out.device,
  837. batch_size=encoder_out.size(
  838. 0))
  839. mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  840. batch_size=encoder_out.size(0))
  841. encoder_out = encoder_out * mask_shfit_chunk
  842. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
  843. ys_out_pad,
  844. encoder_out_mask,
  845. ignore_id=self.ignore_id,
  846. mask_chunk_predictor=mask_chunk_predictor,
  847. target_label_length=ys_in_lens,
  848. )
  849. predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
  850. encoder_out_lens)
  851. scama_mask = None
  852. if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
  853. encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
  854. attention_chunk_center_bias = 0
  855. attention_chunk_size = encoder_chunk_size
  856. decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
  857. mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  858. device=encoder_out.device,
  859. batch_size=encoder_out.size(
  860. 0))
  861. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
  862. predictor_alignments=predictor_alignments,
  863. encoder_sequence_length=encoder_out_lens,
  864. chunk_size=1,
  865. encoder_chunk_size=encoder_chunk_size,
  866. attention_chunk_center_bias=attention_chunk_center_bias,
  867. attention_chunk_size=attention_chunk_size,
  868. attention_chunk_type=self.decoder_attention_chunk_type,
  869. step=None,
  870. predictor_mask_chunk_hopping=mask_chunk_predictor,
  871. decoder_att_look_back_factor=decoder_att_look_back_factor,
  872. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  873. target_length=ys_in_lens,
  874. is_training=self.training,
  875. )
  876. elif self.encoder.overlap_chunk_cls is not None:
  877. encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  878. chunk_outs=None)
  879. return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
  880. def calc_predictor_mask2(
  881. self,
  882. encoder_out: torch.Tensor,
  883. encoder_out_lens: torch.Tensor,
  884. ys_pad: torch.Tensor = None,
  885. ys_pad_lens: torch.Tensor = None,
  886. ):
  887. # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  888. # ys_in_lens = ys_pad_lens + 1
  889. ys_out_pad, ys_in_lens = None, None
  890. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  891. device=encoder_out.device)[:, None, :]
  892. mask_chunk_predictor = None
  893. if self.encoder2.overlap_chunk_cls is not None:
  894. mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
  895. device=encoder_out.device,
  896. batch_size=encoder_out.size(
  897. 0))
  898. mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  899. batch_size=encoder_out.size(0))
  900. encoder_out = encoder_out * mask_shfit_chunk
  901. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
  902. ys_out_pad,
  903. encoder_out_mask,
  904. ignore_id=self.ignore_id,
  905. mask_chunk_predictor=mask_chunk_predictor,
  906. target_label_length=ys_in_lens,
  907. )
  908. predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
  909. encoder_out_lens)
  910. scama_mask = None
  911. if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
  912. encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
  913. attention_chunk_center_bias = 0
  914. attention_chunk_size = encoder_chunk_size
  915. decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
  916. mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  917. device=encoder_out.device,
  918. batch_size=encoder_out.size(
  919. 0))
  920. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
  921. predictor_alignments=predictor_alignments,
  922. encoder_sequence_length=encoder_out_lens,
  923. chunk_size=1,
  924. encoder_chunk_size=encoder_chunk_size,
  925. attention_chunk_center_bias=attention_chunk_center_bias,
  926. attention_chunk_size=attention_chunk_size,
  927. attention_chunk_type=self.decoder_attention_chunk_type2,
  928. step=None,
  929. predictor_mask_chunk_hopping=mask_chunk_predictor,
  930. decoder_att_look_back_factor=decoder_att_look_back_factor,
  931. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  932. target_length=ys_in_lens,
  933. is_training=self.training,
  934. )
  935. elif self.encoder2.overlap_chunk_cls is not None:
  936. encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  937. chunk_outs=None)
  938. return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
  939. def _calc_ctc_loss(
  940. self,
  941. encoder_out: torch.Tensor,
  942. encoder_out_lens: torch.Tensor,
  943. ys_pad: torch.Tensor,
  944. ys_pad_lens: torch.Tensor,
  945. ):
  946. # Calc CTC loss
  947. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  948. # Calc CER using CTC
  949. cer_ctc = None
  950. if not self.training and self.error_calculator is not None:
  951. ys_hat = self.ctc.argmax(encoder_out).data
  952. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  953. return loss_ctc, cer_ctc
  954. def _calc_ctc_loss2(
  955. self,
  956. encoder_out: torch.Tensor,
  957. encoder_out_lens: torch.Tensor,
  958. ys_pad: torch.Tensor,
  959. ys_pad_lens: torch.Tensor,
  960. ):
  961. # Calc CTC loss
  962. loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  963. # Calc CER using CTC
  964. cer_ctc = None
  965. if not self.training and self.error_calculator is not None:
  966. ys_hat = self.ctc2.argmax(encoder_out).data
  967. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  968. return loss_ctc, cer_ctc