# e2e_uni_asr.py
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. from funasr.models.e2e_asr_common import ErrorCalculator
  11. from funasr.modules.nets_utils import th_accuracy
  12. from funasr.modules.add_sos_eos import add_sos_eos
  13. from funasr.losses.label_smoothing_loss import (
  14. LabelSmoothingLoss, # noqa: H301
  15. )
  16. from funasr.models.ctc import CTC
  17. from funasr.models.decoder.abs_decoder import AbsDecoder
  18. from funasr.models.encoder.abs_encoder import AbsEncoder
  19. from funasr.models.frontend.abs_frontend import AbsFrontend
  20. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  21. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  22. from funasr.models.specaug.abs_specaug import AbsSpecAug
  23. from funasr.layers.abs_normalize import AbsNormalize
  24. from funasr.torch_utils.device_funcs import force_gatherable
  25. from funasr.models.base_model import FunASRModel
  26. from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
  27. from funasr.models.predictor.cif import mae_loss
  28. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  29. from torch.cuda.amp import autocast
  30. else:
  31. # Nothing to do if torch<1.6.0
  32. @contextmanager
  33. def autocast(enabled=True):
  34. yield
  35. class UniASR(FunASRModel):
  36. """
  37. Author: Speech Lab of DAMO Academy, Alibaba Group
  38. """
  39. def __init__(
  40. self,
  41. vocab_size: int,
  42. token_list: Union[Tuple[str, ...], List[str]],
  43. frontend: Optional[AbsFrontend],
  44. specaug: Optional[AbsSpecAug],
  45. normalize: Optional[AbsNormalize],
  46. encoder: AbsEncoder,
  47. decoder: AbsDecoder,
  48. ctc: CTC,
  49. ctc_weight: float = 0.5,
  50. interctc_weight: float = 0.0,
  51. ignore_id: int = -1,
  52. lsm_weight: float = 0.0,
  53. length_normalized_loss: bool = False,
  54. report_cer: bool = True,
  55. report_wer: bool = True,
  56. sym_space: str = "<space>",
  57. sym_blank: str = "<blank>",
  58. extract_feats_in_collect_stats: bool = True,
  59. predictor=None,
  60. predictor_weight: float = 0.0,
  61. decoder_attention_chunk_type: str = 'chunk',
  62. encoder2: AbsEncoder = None,
  63. decoder2: AbsDecoder = None,
  64. ctc2: CTC = None,
  65. ctc_weight2: float = 0.5,
  66. interctc_weight2: float = 0.0,
  67. predictor2=None,
  68. predictor_weight2: float = 0.0,
  69. decoder_attention_chunk_type2: str = 'chunk',
  70. stride_conv=None,
  71. loss_weight_model1: float = 0.5,
  72. enable_maas_finetune: bool = False,
  73. freeze_encoder2: bool = False,
  74. preencoder: Optional[AbsPreEncoder] = None,
  75. postencoder: Optional[AbsPostEncoder] = None,
  76. encoder1_encoder2_joint_training: bool = True,
  77. ):
  78. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  79. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  80. super().__init__()
  81. self.blank_id = 0
  82. self.sos = 1
  83. self.eos = 2
  84. self.vocab_size = vocab_size
  85. self.ignore_id = ignore_id
  86. self.ctc_weight = ctc_weight
  87. self.interctc_weight = interctc_weight
  88. self.token_list = token_list.copy()
  89. self.frontend = frontend
  90. self.specaug = specaug
  91. self.normalize = normalize
  92. self.preencoder = preencoder
  93. self.postencoder = postencoder
  94. self.encoder = encoder
  95. if not hasattr(self.encoder, "interctc_use_conditioning"):
  96. self.encoder.interctc_use_conditioning = False
  97. if self.encoder.interctc_use_conditioning:
  98. self.encoder.conditioning_layer = torch.nn.Linear(
  99. vocab_size, self.encoder.output_size()
  100. )
  101. self.error_calculator = None
  102. # we set self.decoder = None in the CTC mode since
  103. # self.decoder parameters were never used and PyTorch complained
  104. # and threw an Exception in the multi-GPU experiment.
  105. # thanks Jeff Farris for pointing out the issue.
  106. if ctc_weight == 1.0:
  107. self.decoder = None
  108. else:
  109. self.decoder = decoder
  110. self.criterion_att = LabelSmoothingLoss(
  111. size=vocab_size,
  112. padding_idx=ignore_id,
  113. smoothing=lsm_weight,
  114. normalize_length=length_normalized_loss,
  115. )
  116. if report_cer or report_wer:
  117. self.error_calculator = ErrorCalculator(
  118. token_list, sym_space, sym_blank, report_cer, report_wer
  119. )
  120. if ctc_weight == 0.0:
  121. self.ctc = None
  122. else:
  123. self.ctc = ctc
  124. self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
  125. self.predictor = predictor
  126. self.predictor_weight = predictor_weight
  127. self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
  128. self.step_cur = 0
  129. if self.encoder.overlap_chunk_cls is not None:
  130. from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
  131. self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
  132. self.decoder_attention_chunk_type = decoder_attention_chunk_type
  133. self.encoder2 = encoder2
  134. self.decoder2 = decoder2
  135. self.ctc_weight2 = ctc_weight2
  136. if ctc_weight2 == 0.0:
  137. self.ctc2 = None
  138. else:
  139. self.ctc2 = ctc2
  140. self.interctc_weight2 = interctc_weight2
  141. self.predictor2 = predictor2
  142. self.predictor_weight2 = predictor_weight2
  143. self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
  144. self.stride_conv = stride_conv
  145. self.loss_weight_model1 = loss_weight_model1
  146. if self.encoder2.overlap_chunk_cls is not None:
  147. from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
  148. self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
  149. self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
  150. self.enable_maas_finetune = enable_maas_finetune
  151. self.freeze_encoder2 = freeze_encoder2
  152. self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: int = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Runs the two-pass UniASR training step: the pass-1 (streaming) losses
        are computed when ``loss_weight_model1 > 0`` and the pass-2 (offline)
        losses when ``loss_weight_model1 < 1``; the final loss is their
        ``loss_weight_model1``-weighted interpolation.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            decoding_ind: optional chunk-configuration index handed to the
                encoder's chunk handler; ``None`` lets it choose (randomly
                during training).

        Returns:
            (loss, stats dict, batch-size weight), gathered for DataParallel.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        # for data-parallel: trim padding to the longest item in this sub-batch
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        # Select the chunk configuration used by both passes for this step.
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        # 1. Encoder (pass 1); frozen under MaaS finetuning.
        if self.enable_maas_finetune:
            with torch.no_grad():
                speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            # Intermediate-CTC encoders return (encoder_out, [(layer_idx, out), ...]).
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        loss_pre = None
        loss, loss1, loss2 = 0.0, 0.0, 0.0
        if self.loss_weight_model1 > 0.0:
            ## model1
            # NOTE(review): the two branches below are token-for-token identical;
            # the only difference is that MaaS finetuning wraps the pass-1 loss
            # computation in torch.no_grad().
            # 1. CTC branch
            if self.enable_maas_finetune:
                with torch.no_grad():
                    if self.ctc_weight != 0.0:
                        if self.encoder.overlap_chunk_cls is not None:
                            # Strip chunk overlap before CTC scoring.
                            # NOTE(review): encoder_out_ctc is only bound inside this
                            # branch — presumably overlap_chunk_cls is always set for
                            # UniASR (random_choice above relies on it too).
                            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                                encoder_out,
                                encoder_out_lens,
                                chunk_outs=None)
                        loss_ctc, cer_ctc = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        # Collect CTC branch stats
                        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                        stats["cer_ctc"] = cer_ctc
                    # Intermediate CTC (optional)
                    loss_interctc = 0.0
                    if self.interctc_weight != 0.0 and intermediate_outs is not None:
                        for layer_idx, intermediate_out in intermediate_outs:
                            # we assume intermediate_out has the same length & padding
                            # as those of encoder_out
                            if self.encoder.overlap_chunk_cls is not None:
                                encoder_out_ctc, encoder_out_lens_ctc = \
                                    self.encoder.overlap_chunk_cls.remove_chunk(
                                        intermediate_out,
                                        encoder_out_lens,
                                        chunk_outs=None)
                            loss_ic, cer_ic = self._calc_ctc_loss(
                                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                            )
                            loss_interctc = loss_interctc + loss_ic
                            # Collect Intermediate CTC stats
                            stats["loss_interctc_layer{}".format(layer_idx)] = (
                                loss_ic.detach() if loss_ic is not None else None
                            )
                            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                        loss_interctc = loss_interctc / len(intermediate_outs)
                        # calculate whole encoder loss
                        loss_ctc = (
                            1 - self.interctc_weight
                        ) * loss_ctc + self.interctc_weight * loss_interctc
                    # 2b. Attention decoder branch (with CIF/SCAMA predictor)
                    if self.ctc_weight != 1.0:
                        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                            encoder_out, encoder_out_lens, text, text_lengths
                        )
                    # 3. CTC-Att loss definition
                    if self.ctc_weight == 0.0:
                        loss = loss_att + loss_pre * self.predictor_weight
                    elif self.ctc_weight == 1.0:
                        loss = loss_ctc
                    else:
                        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                    # Collect Attn branch stats
                    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                    stats["acc"] = acc_att
                    stats["cer"] = cer_att
                    stats["wer"] = wer_att
                    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            else:
                if self.ctc_weight != 0.0:
                    if self.encoder.overlap_chunk_cls is not None:
                        # Strip chunk overlap before CTC scoring.
                        encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                            encoder_out,
                            encoder_out_lens,
                            chunk_outs=None)
                    loss_ctc, cer_ctc = self._calc_ctc_loss(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                    # Collect CTC branch stats
                    stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                    stats["cer_ctc"] = cer_ctc
                # Intermediate CTC (optional)
                loss_interctc = 0.0
                if self.interctc_weight != 0.0 and intermediate_outs is not None:
                    for layer_idx, intermediate_out in intermediate_outs:
                        # we assume intermediate_out has the same length & padding
                        # as those of encoder_out
                        if self.encoder.overlap_chunk_cls is not None:
                            encoder_out_ctc, encoder_out_lens_ctc = \
                                self.encoder.overlap_chunk_cls.remove_chunk(
                                    intermediate_out,
                                    encoder_out_lens,
                                    chunk_outs=None)
                        loss_ic, cer_ic = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        loss_interctc = loss_interctc + loss_ic
                        # Collect Intermediate CTC stats
                        stats["loss_interctc_layer{}".format(layer_idx)] = (
                            loss_ic.detach() if loss_ic is not None else None
                        )
                        stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                    loss_interctc = loss_interctc / len(intermediate_outs)
                    # calculate whole encoder loss
                    loss_ctc = (
                        1 - self.interctc_weight
                    ) * loss_ctc + self.interctc_weight * loss_interctc
                # 2b. Attention decoder branch (with CIF/SCAMA predictor)
                if self.ctc_weight != 1.0:
                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                        encoder_out, encoder_out_lens, text, text_lengths
                    )
                # 3. CTC-Att loss definition
                if self.ctc_weight == 0.0:
                    loss = loss_att + loss_pre * self.predictor_weight
                elif self.ctc_weight == 1.0:
                    loss = loss_ctc
                else:
                    loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                # Collect Attn branch stats
                stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                stats["acc"] = acc_att
                stats["cer"] = cer_att
                stats["wer"] = wer_att
                stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            loss1 = loss
        if self.loss_weight_model1 < 1.0:
            ## model2
            # encoder2: consumes pass-1 encoder output + raw features.
            if self.freeze_encoder2:
                with torch.no_grad():
                    encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
            else:
                encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
            intermediate_outs = None
            if isinstance(encoder_out, tuple):
                intermediate_outs = encoder_out[1]
                encoder_out = encoder_out[0]
            # CTC2
            if self.ctc_weight2 != 0.0:
                if self.encoder2.overlap_chunk_cls is not None:
                    encoder_out_ctc, encoder_out_lens_ctc = \
                        self.encoder2.overlap_chunk_cls.remove_chunk(
                            encoder_out,
                            encoder_out_lens,
                            chunk_outs=None,
                        )
                loss_ctc, cer_ctc = self._calc_ctc_loss2(
                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                )
                # Collect CTC branch stats
                stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None
                stats["cer_ctc2"] = cer_ctc
            # Intermediate CTC (optional)
            loss_interctc = 0.0
            if self.interctc_weight2 != 0.0 and intermediate_outs is not None:
                for layer_idx, intermediate_out in intermediate_outs:
                    # we assume intermediate_out has the same length & padding
                    # as those of encoder_out
                    if self.encoder2.overlap_chunk_cls is not None:
                        encoder_out_ctc, encoder_out_lens_ctc = \
                            self.encoder2.overlap_chunk_cls.remove_chunk(
                                intermediate_out,
                                encoder_out_lens,
                                chunk_outs=None)
                    loss_ic, cer_ic = self._calc_ctc_loss2(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                    loss_interctc = loss_interctc + loss_ic
                    # Collect Intermediate CTC stats (keys suffixed with "2")
                    stats["loss_interctc_layer{}2".format(layer_idx)] = (
                        loss_ic.detach() if loss_ic is not None else None
                    )
                    stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic
                loss_interctc = loss_interctc / len(intermediate_outs)
                # calculate whole encoder loss
                loss_ctc = (
                    1 - self.interctc_weight2
                ) * loss_ctc + self.interctc_weight2 * loss_interctc
            # 2b. Attention decoder branch
            if self.ctc_weight2 != 1.0:
                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
                    encoder_out, encoder_out_lens, text, text_lengths
                )
            # 3. CTC-Att loss definition
            if self.ctc_weight2 == 0.0:
                loss = loss_att + loss_pre * self.predictor_weight2
            elif self.ctc_weight2 == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight2 * loss_ctc + (
                    1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2
            # Collect Attn branch stats
            stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
            stats["acc2"] = acc_att
            stats["cer2"] = cer_att
            stats["wer2"] = wer_att
            stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
            loss2 = loss
        # Interpolate the two pass losses.
        # NOTE(review): when loss_weight_model1 is exactly 0.0 or 1.0 the skipped
        # pass leaves loss1/loss2 as the Python float 0.0, and .detach() below
        # would raise — presumably loss_weight_model1 is kept in (0, 1); confirm.
        loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
        stats["loss1"] = torch.clone(loss1.detach())
        stats["loss2"] = torch.clone(loss2.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  393. def collect_feats(
  394. self,
  395. speech: torch.Tensor,
  396. speech_lengths: torch.Tensor,
  397. text: torch.Tensor,
  398. text_lengths: torch.Tensor,
  399. ) -> Dict[str, torch.Tensor]:
  400. if self.extract_feats_in_collect_stats:
  401. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  402. else:
  403. # Generate dummy stats if extract_feats_in_collect_stats is False
  404. logging.warning(
  405. "Generating dummy stats for feats and feats_lengths, "
  406. "because encoder_conf.extract_feats_in_collect_stats is "
  407. f"{self.extract_feats_in_collect_stats}"
  408. )
  409. feats, feats_lengths = speech, speech_lengths
  410. return {"feats": feats, "feats_lengths": feats_lengths}
  411. def encode(
  412. self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
  413. ) -> Tuple[torch.Tensor, torch.Tensor]:
  414. """Frontend + Encoder. Note that this method is used by asr_inference.py
  415. Args:
  416. speech: (Batch, Length, ...)
  417. speech_lengths: (Batch, )
  418. """
  419. with autocast(False):
  420. # 1. Extract feats
  421. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  422. # 2. Data augmentation
  423. if self.specaug is not None and self.training:
  424. feats, feats_lengths = self.specaug(feats, feats_lengths)
  425. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  426. if self.normalize is not None:
  427. feats, feats_lengths = self.normalize(feats, feats_lengths)
  428. speech_raw = feats.clone().to(feats.device)
  429. # Pre-encoder, e.g. used for raw input data
  430. if self.preencoder is not None:
  431. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  432. # 4. Forward encoder
  433. # feats: (Batch, Length, Dim)
  434. # -> encoder_out: (Batch, Length2, Dim2)
  435. if self.encoder.interctc_use_conditioning:
  436. encoder_out, encoder_out_lens, _ = self.encoder(
  437. feats, feats_lengths, ctc=self.ctc, ind=ind
  438. )
  439. else:
  440. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
  441. intermediate_outs = None
  442. if isinstance(encoder_out, tuple):
  443. intermediate_outs = encoder_out[1]
  444. encoder_out = encoder_out[0]
  445. # Post-encoder, e.g. NLU
  446. if self.postencoder is not None:
  447. encoder_out, encoder_out_lens = self.postencoder(
  448. encoder_out, encoder_out_lens
  449. )
  450. assert encoder_out.size(0) == speech.size(0), (
  451. encoder_out.size(),
  452. speech.size(0),
  453. )
  454. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  455. encoder_out.size(),
  456. encoder_out_lens.max(),
  457. )
  458. if intermediate_outs is not None:
  459. return (encoder_out, intermediate_outs), encoder_out_lens
  460. return speech_raw, encoder_out, encoder_out_lens
  461. def encode2(
  462. self,
  463. encoder_out: torch.Tensor,
  464. encoder_out_lens: torch.Tensor,
  465. speech: torch.Tensor,
  466. speech_lengths: torch.Tensor,
  467. ind: int = 0,
  468. ) -> Tuple[torch.Tensor, torch.Tensor]:
  469. """Frontend + Encoder. Note that this method is used by asr_inference.py
  470. Args:
  471. speech: (Batch, Length, ...)
  472. speech_lengths: (Batch, )
  473. """
  474. # with autocast(False):
  475. # # 1. Extract feats
  476. # feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  477. #
  478. # # 2. Data augmentation
  479. # if self.specaug is not None and self.training:
  480. # feats, feats_lengths = self.specaug(feats, feats_lengths)
  481. #
  482. # # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  483. # if self.normalize is not None:
  484. # feats, feats_lengths = self.normalize(feats, feats_lengths)
  485. # Pre-encoder, e.g. used for raw input data
  486. # if self.preencoder is not None:
  487. # feats, feats_lengths = self.preencoder(feats, feats_lengths)
  488. encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
  489. encoder_out,
  490. encoder_out_lens,
  491. chunk_outs=None,
  492. )
  493. # residual_input
  494. encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
  495. encoder_out_lens = encoder_out_lens_rm
  496. if self.stride_conv is not None:
  497. speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
  498. if not self.encoder1_encoder2_joint_training:
  499. speech = speech.detach()
  500. speech_lengths = speech_lengths.detach()
  501. # 4. Forward encoder
  502. # feats: (Batch, Length, Dim)
  503. # -> encoder_out: (Batch, Length2, Dim2)
  504. if self.encoder2.interctc_use_conditioning:
  505. encoder_out, encoder_out_lens, _ = self.encoder2(
  506. speech, speech_lengths, ctc=self.ctc2, ind=ind
  507. )
  508. else:
  509. encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
  510. intermediate_outs = None
  511. if isinstance(encoder_out, tuple):
  512. intermediate_outs = encoder_out[1]
  513. encoder_out = encoder_out[0]
  514. # # Post-encoder, e.g. NLU
  515. # if self.postencoder is not None:
  516. # encoder_out, encoder_out_lens = self.postencoder(
  517. # encoder_out, encoder_out_lens
  518. # )
  519. assert encoder_out.size(0) == speech.size(0), (
  520. encoder_out.size(),
  521. speech.size(0),
  522. )
  523. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  524. encoder_out.size(),
  525. encoder_out_lens.max(),
  526. )
  527. if intermediate_outs is not None:
  528. return (encoder_out, intermediate_outs), encoder_out_lens
  529. return encoder_out, encoder_out_lens
  530. def _extract_feats(
  531. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  532. ) -> Tuple[torch.Tensor, torch.Tensor]:
  533. assert speech_lengths.dim() == 1, speech_lengths.shape
  534. # for data-parallel
  535. speech = speech[:, : speech_lengths.max()]
  536. if self.frontend is not None:
  537. # Frontend
  538. # e.g. STFT and Feature extract
  539. # data_loader may send time-domain signal in this case
  540. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  541. feats, feats_lengths = self.frontend(speech, speech_lengths)
  542. else:
  543. # No frontend and no feature extract
  544. feats, feats_lengths = speech, speech_lengths
  545. return feats, feats_lengths
  546. def nll(
  547. self,
  548. encoder_out: torch.Tensor,
  549. encoder_out_lens: torch.Tensor,
  550. ys_pad: torch.Tensor,
  551. ys_pad_lens: torch.Tensor,
  552. ) -> torch.Tensor:
  553. """Compute negative log likelihood(nll) from transformer-decoder
  554. Normally, this function is called in batchify_nll.
  555. Args:
  556. encoder_out: (Batch, Length, Dim)
  557. encoder_out_lens: (Batch,)
  558. ys_pad: (Batch, Length)
  559. ys_pad_lens: (Batch,)
  560. """
  561. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  562. ys_in_lens = ys_pad_lens + 1
  563. # 1. Forward decoder
  564. decoder_out, _ = self.decoder(
  565. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  566. ) # [batch, seqlen, dim]
  567. batch_size = decoder_out.size(0)
  568. decoder_num_class = decoder_out.size(2)
  569. # nll: negative log-likelihood
  570. nll = torch.nn.functional.cross_entropy(
  571. decoder_out.view(-1, decoder_num_class),
  572. ys_out_pad.view(-1),
  573. ignore_index=self.ignore_id,
  574. reduction="none",
  575. )
  576. nll = nll.view(batch_size, -1)
  577. nll = nll.sum(dim=1)
  578. assert nll.size(0) == batch_size
  579. return nll
  580. def batchify_nll(
  581. self,
  582. encoder_out: torch.Tensor,
  583. encoder_out_lens: torch.Tensor,
  584. ys_pad: torch.Tensor,
  585. ys_pad_lens: torch.Tensor,
  586. batch_size: int = 100,
  587. ):
  588. """Compute negative log likelihood(nll) from transformer-decoder
  589. To avoid OOM, this fuction seperate the input into batches.
  590. Then call nll for each batch and combine and return results.
  591. Args:
  592. encoder_out: (Batch, Length, Dim)
  593. encoder_out_lens: (Batch,)
  594. ys_pad: (Batch, Length)
  595. ys_pad_lens: (Batch,)
  596. batch_size: int, samples each batch contain when computing nll,
  597. you may change this to avoid OOM or increase
  598. GPU memory usage
  599. """
  600. total_num = encoder_out.size(0)
  601. if total_num <= batch_size:
  602. nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  603. else:
  604. nll = []
  605. start_idx = 0
  606. while True:
  607. end_idx = min(start_idx + batch_size, total_num)
  608. batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
  609. batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
  610. batch_ys_pad = ys_pad[start_idx:end_idx, :]
  611. batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
  612. batch_nll = self.nll(
  613. batch_encoder_out,
  614. batch_encoder_out_lens,
  615. batch_ys_pad,
  616. batch_ys_pad_lens,
  617. )
  618. nll.append(batch_nll)
  619. start_idx = end_idx
  620. if start_idx == total_num:
  621. break
  622. nll = torch.cat(nll)
  623. assert nll.size(0) == total_num
  624. return nll
  625. def _calc_att_loss(
  626. self,
  627. encoder_out: torch.Tensor,
  628. encoder_out_lens: torch.Tensor,
  629. ys_pad: torch.Tensor,
  630. ys_pad_lens: torch.Tensor,
  631. ):
  632. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  633. ys_in_lens = ys_pad_lens + 1
  634. # 1. Forward decoder
  635. decoder_out, _ = self.decoder(
  636. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  637. )
  638. # 2. Compute attention loss
  639. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  640. acc_att = th_accuracy(
  641. decoder_out.view(-1, self.vocab_size),
  642. ys_out_pad,
  643. ignore_label=self.ignore_id,
  644. )
  645. # Compute cer/wer using attention-decoder
  646. if self.training or self.error_calculator is None:
  647. cer_att, wer_att = None, None
  648. else:
  649. ys_hat = decoder_out.argmax(dim=-1)
  650. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  651. return loss_att, acc_att, cer_att, wer_att
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Pass-1 attention-decoder loss with the CIF/SCAMA predictor.

        Runs the predictor over the (chunked) encoder output, optionally builds
        the SCAMA cross-attention mask from the predicted frame alignments, and
        decodes with that mask.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre) where loss_pre is
            the MAE between predicted and true target lengths; cer/wer are
            None during training or without an error calculator.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # Valid-frame mask, broadcastable over the chunk axis: (B, 1, T).
        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1),
                                         dtype=encoder_out.dtype,
                                         device=encoder_out.device)[:, None, :]
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            # Masks that hide the overlap/shift frames introduced by chunking,
            # for the predictor and for the encoder output respectively.
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            encoder_out = encoder_out * mask_shfit_chunk
        # Predictor: acoustic embeddings per target token, predicted target
        # length, and per-frame firing weights (alphas).
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Build the streaming (SCAMA) cross-attention mask from the
            # predicted alignments and the current chunk geometry.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            # Non-chunk attention: decode over the de-chunked encoder output.
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens,
                chunk_outs=None)
        # 1. Forward decoder (chunk_mask=None means full attention)
        decoder_out, _ = self.decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # predictor loss: MAE between true and predicted token counts
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder (evaluation only)
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre
def _calc_att_predictor_loss2(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute attention-decoder and predictor losses for the second branch.

    Mirrors ``_calc_att_predictor_loss`` but runs entirely on the second
    encoder/predictor/decoder stack (``encoder2``/``predictor2``/``decoder2``).

    Args:
        encoder_out: Encoder output features; assumed (batch, time, dim) —
            TODO confirm against encoder2's output contract.
        encoder_out_lens: Valid lengths of ``encoder_out`` per utterance.
        ys_pad: Padded target token ids.
        ys_pad_lens: Valid lengths of ``ys_pad`` per utterance.

    Returns:
        Tuple ``(loss_att, acc_att, cer_att, wer_att, loss_pre)`` where
        ``cer_att``/``wer_att`` are None during training or when no error
        calculator is configured.
    """
    # Prepend <sos> / append <eos>; targets grow by one token.
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_in_lens = ys_pad_lens + 1
    # (batch, 1, time) mask over valid encoder frames.
    encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
                                     device=encoder_out.device)[:, None, :]
    mask_chunk_predictor = None
    if self.encoder2.overlap_chunk_cls is not None:
        # Chunk-overlap mode: mask out overlap regions for the predictor and
        # zero shifted-chunk frames in the encoder output.
        mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
                                                                                        device=encoder_out.device,
                                                                                        batch_size=encoder_out.size(
                                                                                            0))
        mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
                                                                                batch_size=encoder_out.size(0))
        encoder_out = encoder_out * mask_shfit_chunk
    # Predictor yields acoustic embeddings, a token-count estimate, and the
    # per-frame firing weights (alphas) used to derive alignments.
    pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
                                                                           ys_out_pad,
                                                                           encoder_out_mask,
                                                                           ignore_id=self.ignore_id,
                                                                           mask_chunk_predictor=mask_chunk_predictor,
                                                                           target_label_length=ys_in_lens,
                                                                           )
    predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
                                                                                          encoder_out_lens)
    scama_mask = None
    if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
        # Build the streaming (SCAMA) cross-attention mask from the predicted
        # alignments so the decoder only attends within its chunk window.
        encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
        attention_chunk_center_bias = 0
        attention_chunk_size = encoder_chunk_size
        decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
        mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
                                                                                                        device=encoder_out.device,
                                                                                                        batch_size=encoder_out.size(
                                                                                                            0))
        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
            predictor_alignments=predictor_alignments,
            encoder_sequence_length=encoder_out_lens,
            chunk_size=1,
            encoder_chunk_size=encoder_chunk_size,
            attention_chunk_center_bias=attention_chunk_center_bias,
            attention_chunk_size=attention_chunk_size,
            attention_chunk_type=self.decoder_attention_chunk_type2,
            step=None,
            predictor_mask_chunk_hopping=mask_chunk_predictor,
            decoder_att_look_back_factor=decoder_att_look_back_factor,
            mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
            target_length=ys_in_lens,
            is_training=self.training,
        )
    elif self.encoder2.overlap_chunk_cls is not None:
        # Non-chunk attention: strip the overlap padding before decoding.
        encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
                                                                                     chunk_outs=None)
    # 1. Forward decoder
    decoder_out, _ = self.decoder2(
        encoder_out,
        encoder_out_lens,
        ys_in_pad,
        ys_in_lens,
        chunk_mask=scama_mask,
        pre_acoustic_embeds=pre_acoustic_embeds,
    )
    # 2. Compute attention loss
    loss_att = self.criterion_att(decoder_out, ys_out_pad)
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),
        ys_out_pad,
        ignore_label=self.ignore_id,
    )
    # Predictor loss: penalize mismatch between predicted and true token counts.
    loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
    # Compute cer/wer using attention-decoder (eval only).
    if self.training or self.error_calculator is None:
        cer_att, wer_att = None, None
    else:
        ys_hat = decoder_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
    return loss_att, acc_att, cer_att, wer_att, loss_pre
  818. def calc_predictor_mask(
  819. self,
  820. encoder_out: torch.Tensor,
  821. encoder_out_lens: torch.Tensor,
  822. ys_pad: torch.Tensor = None,
  823. ys_pad_lens: torch.Tensor = None,
  824. ):
  825. # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  826. # ys_in_lens = ys_pad_lens + 1
  827. ys_out_pad, ys_in_lens = None, None
  828. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  829. device=encoder_out.device)[:, None, :]
  830. mask_chunk_predictor = None
  831. if self.encoder.overlap_chunk_cls is not None:
  832. mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
  833. device=encoder_out.device,
  834. batch_size=encoder_out.size(
  835. 0))
  836. mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  837. batch_size=encoder_out.size(0))
  838. encoder_out = encoder_out * mask_shfit_chunk
  839. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
  840. ys_out_pad,
  841. encoder_out_mask,
  842. ignore_id=self.ignore_id,
  843. mask_chunk_predictor=mask_chunk_predictor,
  844. target_label_length=ys_in_lens,
  845. )
  846. predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
  847. encoder_out_lens)
  848. scama_mask = None
  849. if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
  850. encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
  851. attention_chunk_center_bias = 0
  852. attention_chunk_size = encoder_chunk_size
  853. decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
  854. mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  855. device=encoder_out.device,
  856. batch_size=encoder_out.size(
  857. 0))
  858. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
  859. predictor_alignments=predictor_alignments,
  860. encoder_sequence_length=encoder_out_lens,
  861. chunk_size=1,
  862. encoder_chunk_size=encoder_chunk_size,
  863. attention_chunk_center_bias=attention_chunk_center_bias,
  864. attention_chunk_size=attention_chunk_size,
  865. attention_chunk_type=self.decoder_attention_chunk_type,
  866. step=None,
  867. predictor_mask_chunk_hopping=mask_chunk_predictor,
  868. decoder_att_look_back_factor=decoder_att_look_back_factor,
  869. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  870. target_length=ys_in_lens,
  871. is_training=self.training,
  872. )
  873. elif self.encoder.overlap_chunk_cls is not None:
  874. encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  875. chunk_outs=None)
  876. return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
  877. def calc_predictor_mask2(
  878. self,
  879. encoder_out: torch.Tensor,
  880. encoder_out_lens: torch.Tensor,
  881. ys_pad: torch.Tensor = None,
  882. ys_pad_lens: torch.Tensor = None,
  883. ):
  884. # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  885. # ys_in_lens = ys_pad_lens + 1
  886. ys_out_pad, ys_in_lens = None, None
  887. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  888. device=encoder_out.device)[:, None, :]
  889. mask_chunk_predictor = None
  890. if self.encoder2.overlap_chunk_cls is not None:
  891. mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
  892. device=encoder_out.device,
  893. batch_size=encoder_out.size(
  894. 0))
  895. mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  896. batch_size=encoder_out.size(0))
  897. encoder_out = encoder_out * mask_shfit_chunk
  898. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
  899. ys_out_pad,
  900. encoder_out_mask,
  901. ignore_id=self.ignore_id,
  902. mask_chunk_predictor=mask_chunk_predictor,
  903. target_label_length=ys_in_lens,
  904. )
  905. predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
  906. encoder_out_lens)
  907. scama_mask = None
  908. if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
  909. encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
  910. attention_chunk_center_bias = 0
  911. attention_chunk_size = encoder_chunk_size
  912. decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
  913. mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  914. device=encoder_out.device,
  915. batch_size=encoder_out.size(
  916. 0))
  917. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
  918. predictor_alignments=predictor_alignments,
  919. encoder_sequence_length=encoder_out_lens,
  920. chunk_size=1,
  921. encoder_chunk_size=encoder_chunk_size,
  922. attention_chunk_center_bias=attention_chunk_center_bias,
  923. attention_chunk_size=attention_chunk_size,
  924. attention_chunk_type=self.decoder_attention_chunk_type2,
  925. step=None,
  926. predictor_mask_chunk_hopping=mask_chunk_predictor,
  927. decoder_att_look_back_factor=decoder_att_look_back_factor,
  928. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  929. target_length=ys_in_lens,
  930. is_training=self.training,
  931. )
  932. elif self.encoder2.overlap_chunk_cls is not None:
  933. encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  934. chunk_outs=None)
  935. return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
  936. def _calc_ctc_loss(
  937. self,
  938. encoder_out: torch.Tensor,
  939. encoder_out_lens: torch.Tensor,
  940. ys_pad: torch.Tensor,
  941. ys_pad_lens: torch.Tensor,
  942. ):
  943. # Calc CTC loss
  944. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  945. # Calc CER using CTC
  946. cer_ctc = None
  947. if not self.training and self.error_calculator is not None:
  948. ys_hat = self.ctc.argmax(encoder_out).data
  949. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  950. return loss_ctc, cer_ctc
  951. def _calc_ctc_loss2(
  952. self,
  953. encoder_out: torch.Tensor,
  954. encoder_out_lens: torch.Tensor,
  955. ys_pad: torch.Tensor,
  956. ys_pad_lens: torch.Tensor,
  957. ):
  958. # Calc CTC loss
  959. loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  960. # Calc CER using CTC
  961. cer_ctc = None
  962. if not self.training and self.error_calculator is not None:
  963. ys_hat = self.ctc2.argmax(encoder_out).data
  964. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  965. return loss_ctc, cer_ctc