# e2e_uni_asr.py
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. from typeguard import check_argument_types
  11. from funasr.models.e2e_asr_common import ErrorCalculator
  12. from funasr.modules.nets_utils import th_accuracy
  13. from funasr.modules.add_sos_eos import add_sos_eos
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, # noqa: H301
  16. )
  17. from funasr.models.ctc import CTC
  18. from funasr.models.decoder.abs_decoder import AbsDecoder
  19. from funasr.models.encoder.abs_encoder import AbsEncoder
  20. from funasr.models.frontend.abs_frontend import AbsFrontend
  21. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  22. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  23. from funasr.models.specaug.abs_specaug import AbsSpecAug
  24. from funasr.layers.abs_normalize import AbsNormalize
  25. from funasr.torch_utils.device_funcs import force_gatherable
  26. from funasr.models.base_model import FunASRModel
  27. from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
  28. from funasr.models.predictor.cif import mae_loss
# torch.cuda.amp.autocast is only available from torch 1.6.0 onwards; on
# older versions install a no-op context manager under the same name so
# call sites (e.g. encode()) need no version checks of their own.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
class UniASR(FunASRModel):
    """UniASR: ASR model with two encoder/decoder branches (model1 and
    model2) whose losses are interpolated by ``loss_weight_model1``.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        decoder_attention_chunk_type: str = 'chunk',
        encoder2: AbsEncoder = None,
        decoder2: AbsDecoder = None,
        ctc2: CTC = None,
        ctc_weight2: float = 0.5,
        interctc_weight2: float = 0.0,
        predictor2=None,
        predictor_weight2: float = 0.0,
        decoder_attention_chunk_type2: str = 'chunk',
        stride_conv=None,
        loss_weight_model1: float = 0.5,
        enable_maas_finetune: bool = False,
        freeze_encoder2: bool = False,
        encoder1_encoder2_joint_training: bool = True,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__()
        # Special-token ids are fixed by convention here, not derived from
        # token_list: blank=0, sos=1, eos=2.
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        # Copy so later mutation by the caller cannot change our vocabulary.
        self.token_list = token_list.copy()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        # Default the flag off for encoders that do not define it, then add
        # the conditioning projection only when intermediate-CTC
        # conditioning is actually enabled.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.error_calculator = None
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        # First-pass predictor (CIF-style, per the mae_loss criterion below).
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0
        # When the first encoder organizes output in overlapping chunks,
        # bind the scama cross-attention mask builder for the decoder.
        if self.encoder.overlap_chunk_cls is not None:
            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type = decoder_attention_chunk_type
        # Second-pass (model2) components.
        self.encoder2 = encoder2
        self.decoder2 = decoder2
        self.ctc_weight2 = ctc_weight2
        if ctc_weight2 == 0.0:
            self.ctc2 = None
        else:
            self.ctc2 = ctc2
        self.interctc_weight2 = interctc_weight2
        self.predictor2 = predictor2
        self.predictor_weight2 = predictor_weight2
        self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
        # Strided conv bridging first-pass output (+ raw feats) into encoder2.
        self.stride_conv = stride_conv
        # Interpolation weight between model1 loss and model2 loss.
        self.loss_weight_model1 = loss_weight_model1
        if self.encoder2.overlap_chunk_cls is not None:
            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
        # enable_maas_finetune: compute model1 under no_grad (see forward()).
        self.enable_maas_finetune = enable_maas_finetune
        # freeze_encoder2: run encode2 under no_grad (see forward()).
        self.freeze_encoder2 = freeze_encoder2
        # If False, gradients are blocked between encoder1 and encoder2.
        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Runs the first-pass model (encoder/decoder/ctc/predictor) and,
        depending on ``loss_weight_model1``, the second-pass model
        (encoder2/decoder2/ctc2/predictor2), then interpolates both losses.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            decoding_ind: optional chunk-configuration index handed to
                overlap_chunk_cls.random_choice.

        Returns:
            (loss, stats, weight) where weight is the batch size, all made
            gatherable for DataParallel by force_gatherable.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # for data-parallel: strip padding beyond the longest item in batch
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        # Select the chunk configuration for this step (random in training
        # unless decoding_ind pins it).
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)

        # 1. Encoder (no grad when only the second pass is being fine-tuned)
        if self.enable_maas_finetune:
            with torch.no_grad():
                speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        loss_pre = None
        loss, loss1, loss2 = 0.0, 0.0, 0.0

        if self.loss_weight_model1 > 0.0:
            ## model1 (first pass)
            if self.enable_maas_finetune:
                # Frozen first pass: compute its losses without gradients.
                with torch.no_grad():
                    # 1. CTC branch
                    if self.ctc_weight != 0.0:
                        # De-chunk the encoder output before CTC.
                        if self.encoder.overlap_chunk_cls is not None:
                            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                                encoder_out, encoder_out_lens, chunk_outs=None)
                        loss_ctc, cer_ctc = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        # Collect CTC branch stats
                        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                        stats["cer_ctc"] = cer_ctc
                    # Intermediate CTC (optional)
                    loss_interctc = 0.0
                    if self.interctc_weight != 0.0 and intermediate_outs is not None:
                        for layer_idx, intermediate_out in intermediate_outs:
                            # we assume intermediate_out has the same length &
                            # padding as those of encoder_out
                            if self.encoder.overlap_chunk_cls is not None:
                                encoder_out_ctc, encoder_out_lens_ctc = \
                                    self.encoder.overlap_chunk_cls.remove_chunk(
                                        intermediate_out, encoder_out_lens, chunk_outs=None)
                            loss_ic, cer_ic = self._calc_ctc_loss(
                                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                            )
                            loss_interctc = loss_interctc + loss_ic
                            # Collect intermediate CTC stats
                            stats["loss_interctc_layer{}".format(layer_idx)] = (
                                loss_ic.detach() if loss_ic is not None else None
                            )
                            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                        loss_interctc = loss_interctc / len(intermediate_outs)
                        # calculate whole encoder loss
                        loss_ctc = (
                            1 - self.interctc_weight
                        ) * loss_ctc + self.interctc_weight * loss_interctc
                    # 2b. Attention decoder branch (with predictor)
                    if self.ctc_weight != 1.0:
                        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                            encoder_out, encoder_out_lens, text, text_lengths
                        )
                    # 3. CTC-Att loss definition
                    if self.ctc_weight == 0.0:
                        loss = loss_att + loss_pre * self.predictor_weight
                    elif self.ctc_weight == 1.0:
                        loss = loss_ctc
                    else:
                        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                    # Collect Attn branch stats
                    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                    stats["acc"] = acc_att
                    stats["cer"] = cer_att
                    stats["wer"] = wer_att
                    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            else:
                # Identical computation to the no_grad branch above, but with
                # gradients enabled.
                # 1. CTC branch
                if self.ctc_weight != 0.0:
                    if self.encoder.overlap_chunk_cls is not None:
                        encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
                            encoder_out, encoder_out_lens, chunk_outs=None)
                    loss_ctc, cer_ctc = self._calc_ctc_loss(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                    # Collect CTC branch stats
                    stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                    stats["cer_ctc"] = cer_ctc
                # Intermediate CTC (optional)
                loss_interctc = 0.0
                if self.interctc_weight != 0.0 and intermediate_outs is not None:
                    for layer_idx, intermediate_out in intermediate_outs:
                        # we assume intermediate_out has the same length &
                        # padding as those of encoder_out
                        if self.encoder.overlap_chunk_cls is not None:
                            encoder_out_ctc, encoder_out_lens_ctc = \
                                self.encoder.overlap_chunk_cls.remove_chunk(
                                    intermediate_out, encoder_out_lens, chunk_outs=None)
                        loss_ic, cer_ic = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        loss_interctc = loss_interctc + loss_ic
                        # Collect intermediate CTC stats
                        stats["loss_interctc_layer{}".format(layer_idx)] = (
                            loss_ic.detach() if loss_ic is not None else None
                        )
                        stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                    loss_interctc = loss_interctc / len(intermediate_outs)
                    # calculate whole encoder loss
                    loss_ctc = (
                        1 - self.interctc_weight
                    ) * loss_ctc + self.interctc_weight * loss_interctc
                # 2b. Attention decoder branch (with predictor)
                if self.ctc_weight != 1.0:
                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                        encoder_out, encoder_out_lens, text, text_lengths
                    )
                # 3. CTC-Att loss definition
                if self.ctc_weight == 0.0:
                    loss = loss_att + loss_pre * self.predictor_weight
                elif self.ctc_weight == 1.0:
                    loss = loss_ctc
                else:
                    loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                # Collect Attn branch stats
                stats["loss_att"] = loss_att.detach() if loss_att is not None else None
                stats["acc"] = acc_att
                stats["cer"] = cer_att
                stats["wer"] = wer_att
                stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            loss1 = loss

        if self.loss_weight_model1 < 1.0:
            ## model2 (second pass) on top of the first-pass encoder output
            if self.freeze_encoder2:
                with torch.no_grad():
                    encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
            else:
                encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=ind)
            intermediate_outs = None
            if isinstance(encoder_out, tuple):
                intermediate_outs = encoder_out[1]
                encoder_out = encoder_out[0]
            # CTC2 branch
            if self.ctc_weight2 != 0.0:
                if self.encoder2.overlap_chunk_cls is not None:
                    encoder_out_ctc, encoder_out_lens_ctc = \
                        self.encoder2.overlap_chunk_cls.remove_chunk(
                            encoder_out,
                            encoder_out_lens,
                            chunk_outs=None,
                        )
                loss_ctc, cer_ctc = self._calc_ctc_loss2(
                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                )
                # Collect CTC branch stats
                stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None
                stats["cer_ctc2"] = cer_ctc
            # Intermediate CTC (optional)
            loss_interctc = 0.0
            if self.interctc_weight2 != 0.0 and intermediate_outs is not None:
                for layer_idx, intermediate_out in intermediate_outs:
                    # we assume intermediate_out has the same length & padding
                    # as those of encoder_out
                    if self.encoder2.overlap_chunk_cls is not None:
                        encoder_out_ctc, encoder_out_lens_ctc = \
                            self.encoder2.overlap_chunk_cls.remove_chunk(
                                intermediate_out, encoder_out_lens, chunk_outs=None)
                    loss_ic, cer_ic = self._calc_ctc_loss2(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                    loss_interctc = loss_interctc + loss_ic
                    # Collect intermediate CTC stats
                    stats["loss_interctc_layer{}2".format(layer_idx)] = (
                        loss_ic.detach() if loss_ic is not None else None
                    )
                    stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic
                loss_interctc = loss_interctc / len(intermediate_outs)
                # calculate whole encoder loss
                loss_ctc = (
                    1 - self.interctc_weight2
                ) * loss_ctc + self.interctc_weight2 * loss_interctc
            # 2b. Attention decoder branch (with predictor2)
            if self.ctc_weight2 != 1.0:
                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
                    encoder_out, encoder_out_lens, text, text_lengths
                )
            # 3. CTC-Att loss definition
            if self.ctc_weight2 == 0.0:
                loss = loss_att + loss_pre * self.predictor_weight2
            elif self.ctc_weight2 == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight2 * loss_ctc + (
                    1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2
            # Collect Attn branch stats
            stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
            stats["acc2"] = acc_att
            stats["cer2"] = cer_att
            stats["wer2"] = wer_att
            stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
            loss2 = loss

        # Interpolate the two passes.
        # NOTE(review): if loss_weight_model1 is exactly 0.0 or 1.0, the
        # unused lossX remains the float 0.0 and .detach() below would fail
        # — presumably 0 < loss_weight_model1 < 1 in practice; confirm.
        loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
        stats["loss1"] = torch.clone(loss1.detach())
        stats["loss2"] = torch.clone(loss2.detach())
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  395. def collect_feats(
  396. self,
  397. speech: torch.Tensor,
  398. speech_lengths: torch.Tensor,
  399. text: torch.Tensor,
  400. text_lengths: torch.Tensor,
  401. ) -> Dict[str, torch.Tensor]:
  402. if self.extract_feats_in_collect_stats:
  403. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  404. else:
  405. # Generate dummy stats if extract_feats_in_collect_stats is False
  406. logging.warning(
  407. "Generating dummy stats for feats and feats_lengths, "
  408. "because encoder_conf.extract_feats_in_collect_stats is "
  409. f"{self.extract_feats_in_collect_stats}"
  410. )
  411. feats, feats_lengths = speech, speech_lengths
  412. return {"feats": feats, "feats_lengths": feats_lengths}
  413. def encode(
  414. self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
  415. ) -> Tuple[torch.Tensor, torch.Tensor]:
  416. """Frontend + Encoder. Note that this method is used by asr_inference.py
  417. Args:
  418. speech: (Batch, Length, ...)
  419. speech_lengths: (Batch, )
  420. """
  421. with autocast(False):
  422. # 1. Extract feats
  423. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  424. # 2. Data augmentation
  425. if self.specaug is not None and self.training:
  426. feats, feats_lengths = self.specaug(feats, feats_lengths)
  427. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  428. if self.normalize is not None:
  429. feats, feats_lengths = self.normalize(feats, feats_lengths)
  430. speech_raw = feats.clone().to(feats.device)
  431. # Pre-encoder, e.g. used for raw input data
  432. if self.preencoder is not None:
  433. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  434. # 4. Forward encoder
  435. # feats: (Batch, Length, Dim)
  436. # -> encoder_out: (Batch, Length2, Dim2)
  437. if self.encoder.interctc_use_conditioning:
  438. encoder_out, encoder_out_lens, _ = self.encoder(
  439. feats, feats_lengths, ctc=self.ctc, ind=ind
  440. )
  441. else:
  442. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
  443. intermediate_outs = None
  444. if isinstance(encoder_out, tuple):
  445. intermediate_outs = encoder_out[1]
  446. encoder_out = encoder_out[0]
  447. # Post-encoder, e.g. NLU
  448. if self.postencoder is not None:
  449. encoder_out, encoder_out_lens = self.postencoder(
  450. encoder_out, encoder_out_lens
  451. )
  452. assert encoder_out.size(0) == speech.size(0), (
  453. encoder_out.size(),
  454. speech.size(0),
  455. )
  456. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  457. encoder_out.size(),
  458. encoder_out_lens.max(),
  459. )
  460. if intermediate_outs is not None:
  461. return (encoder_out, intermediate_outs), encoder_out_lens
  462. return speech_raw, encoder_out, encoder_out_lens
    def encode2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Second-pass encoder. Note that this method is used by asr_inference.py

        Feeds the de-chunked first-pass encoder output, concatenated with the
        normalized features, through stride_conv and encoder2.

        Args:
            encoder_out: first-pass encoder output (chunk-organized)
            encoder_out_lens: (Batch,)
            speech: normalized features returned by encode() as speech_raw
            speech_lengths: (Batch, )
            ind: chunk-configuration index forwarded to encoder2.
        """
        # Undo the chunk organization of the first-pass encoder output.
        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
            chunk_outs=None,
        )
        # residual_input: concatenate the raw features with the first-pass
        # output along the feature dimension.
        encoder_out = torch.cat((speech, encoder_out_rm), dim=-1)
        encoder_out_lens = encoder_out_lens_rm
        if self.stride_conv is not None:
            speech, speech_lengths = self.stride_conv(encoder_out, encoder_out_lens)
        # NOTE(review): if stride_conv is None, `speech` below is still the
        # raw-feature argument rather than the concatenated tensor —
        # presumably stride_conv is always configured; confirm.
        if not self.encoder1_encoder2_joint_training:
            # Block gradients from flowing back into the first pass.
            speech = speech.detach()
            speech_lengths = speech_lengths.detach()
        # Forward second encoder:
        # (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder2.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder2(
                speech, speech_lengths, ctc=self.ctc2, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
  532. def _extract_feats(
  533. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  534. ) -> Tuple[torch.Tensor, torch.Tensor]:
  535. assert speech_lengths.dim() == 1, speech_lengths.shape
  536. # for data-parallel
  537. speech = speech[:, : speech_lengths.max()]
  538. if self.frontend is not None:
  539. # Frontend
  540. # e.g. STFT and Feature extract
  541. # data_loader may send time-domain signal in this case
  542. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  543. feats, feats_lengths = self.frontend(speech, speech_lengths)
  544. else:
  545. # No frontend and no feature extract
  546. feats, feats_lengths = speech, speech_lengths
  547. return feats, feats_lengths
  548. def nll(
  549. self,
  550. encoder_out: torch.Tensor,
  551. encoder_out_lens: torch.Tensor,
  552. ys_pad: torch.Tensor,
  553. ys_pad_lens: torch.Tensor,
  554. ) -> torch.Tensor:
  555. """Compute negative log likelihood(nll) from transformer-decoder
  556. Normally, this function is called in batchify_nll.
  557. Args:
  558. encoder_out: (Batch, Length, Dim)
  559. encoder_out_lens: (Batch,)
  560. ys_pad: (Batch, Length)
  561. ys_pad_lens: (Batch,)
  562. """
  563. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  564. ys_in_lens = ys_pad_lens + 1
  565. # 1. Forward decoder
  566. decoder_out, _ = self.decoder(
  567. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  568. ) # [batch, seqlen, dim]
  569. batch_size = decoder_out.size(0)
  570. decoder_num_class = decoder_out.size(2)
  571. # nll: negative log-likelihood
  572. nll = torch.nn.functional.cross_entropy(
  573. decoder_out.view(-1, decoder_num_class),
  574. ys_out_pad.view(-1),
  575. ignore_index=self.ignore_id,
  576. reduction="none",
  577. )
  578. nll = nll.view(batch_size, -1)
  579. nll = nll.sum(dim=1)
  580. assert nll.size(0) == batch_size
  581. return nll
  582. def batchify_nll(
  583. self,
  584. encoder_out: torch.Tensor,
  585. encoder_out_lens: torch.Tensor,
  586. ys_pad: torch.Tensor,
  587. ys_pad_lens: torch.Tensor,
  588. batch_size: int = 100,
  589. ):
  590. """Compute negative log likelihood(nll) from transformer-decoder
  591. To avoid OOM, this fuction seperate the input into batches.
  592. Then call nll for each batch and combine and return results.
  593. Args:
  594. encoder_out: (Batch, Length, Dim)
  595. encoder_out_lens: (Batch,)
  596. ys_pad: (Batch, Length)
  597. ys_pad_lens: (Batch,)
  598. batch_size: int, samples each batch contain when computing nll,
  599. you may change this to avoid OOM or increase
  600. GPU memory usage
  601. """
  602. total_num = encoder_out.size(0)
  603. if total_num <= batch_size:
  604. nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  605. else:
  606. nll = []
  607. start_idx = 0
  608. while True:
  609. end_idx = min(start_idx + batch_size, total_num)
  610. batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
  611. batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
  612. batch_ys_pad = ys_pad[start_idx:end_idx, :]
  613. batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
  614. batch_nll = self.nll(
  615. batch_encoder_out,
  616. batch_encoder_out_lens,
  617. batch_ys_pad,
  618. batch_ys_pad_lens,
  619. )
  620. nll.append(batch_nll)
  621. start_idx = end_idx
  622. if start_idx == total_num:
  623. break
  624. nll = torch.cat(nll)
  625. assert nll.size(0) == total_num
  626. return nll
  627. def _calc_att_loss(
  628. self,
  629. encoder_out: torch.Tensor,
  630. encoder_out_lens: torch.Tensor,
  631. ys_pad: torch.Tensor,
  632. ys_pad_lens: torch.Tensor,
  633. ):
  634. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  635. ys_in_lens = ys_pad_lens + 1
  636. # 1. Forward decoder
  637. decoder_out, _ = self.decoder(
  638. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  639. )
  640. # 2. Compute attention loss
  641. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  642. acc_att = th_accuracy(
  643. decoder_out.view(-1, self.vocab_size),
  644. ys_out_pad,
  645. ignore_label=self.ignore_id,
  646. )
  647. # Compute cer/wer using attention-decoder
  648. if self.training or self.error_calculator is None:
  649. cer_att, wer_att = None, None
  650. else:
  651. ys_hat = decoder_out.argmax(dim=-1)
  652. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  653. return loss_att, acc_att, cer_att, wer_att
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """First-pass attention-decoder loss with the CIF-style predictor.

        Returns (loss_att, acc_att, cer_att, wer_att, loss_pre) where
        loss_pre is the predictor's token-count MAE loss.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # (Batch, 1, T) padding mask over the encoder frames.
        encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
                                         device=encoder_out.device)[:, None, :]
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            # Masks that restrict the predictor / shift the chunk layout when
            # the encoder output is chunk-organized.
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            encoder_out = encoder_out * mask_shfit_chunk
        # Predictor: acoustic embeddings, predicted token count, and frame
        # weights (alphas).
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_out_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_in_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens)
        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Build the streaming (scama) cross-attention mask from the
            # predictor alignments and the current chunk configuration.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0))
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_in_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            # Non-chunk decoder attention: de-chunk the encoder output instead.
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None)
        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out,
            encoder_out_lens,
            ys_in_pad,
            ys_in_lens,
            chunk_mask=scama_mask,
            pre_acoustic_embeds=pre_acoustic_embeds,
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # predictor loss: MAE between target length and predicted token count
        loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder (eval only)
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre
  737. def _calc_att_predictor_loss2(
  738. self,
  739. encoder_out: torch.Tensor,
  740. encoder_out_lens: torch.Tensor,
  741. ys_pad: torch.Tensor,
  742. ys_pad_lens: torch.Tensor,
  743. ):
  744. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  745. ys_in_lens = ys_pad_lens + 1
  746. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  747. device=encoder_out.device)[:, None, :]
  748. mask_chunk_predictor = None
  749. if self.encoder2.overlap_chunk_cls is not None:
  750. mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
  751. device=encoder_out.device,
  752. batch_size=encoder_out.size(
  753. 0))
  754. mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  755. batch_size=encoder_out.size(0))
  756. encoder_out = encoder_out * mask_shfit_chunk
  757. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
  758. ys_out_pad,
  759. encoder_out_mask,
  760. ignore_id=self.ignore_id,
  761. mask_chunk_predictor=mask_chunk_predictor,
  762. target_label_length=ys_in_lens,
  763. )
  764. predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
  765. encoder_out_lens)
  766. scama_mask = None
  767. if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
  768. encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
  769. attention_chunk_center_bias = 0
  770. attention_chunk_size = encoder_chunk_size
  771. decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
  772. mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  773. device=encoder_out.device,
  774. batch_size=encoder_out.size(
  775. 0))
  776. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
  777. predictor_alignments=predictor_alignments,
  778. encoder_sequence_length=encoder_out_lens,
  779. chunk_size=1,
  780. encoder_chunk_size=encoder_chunk_size,
  781. attention_chunk_center_bias=attention_chunk_center_bias,
  782. attention_chunk_size=attention_chunk_size,
  783. attention_chunk_type=self.decoder_attention_chunk_type2,
  784. step=None,
  785. predictor_mask_chunk_hopping=mask_chunk_predictor,
  786. decoder_att_look_back_factor=decoder_att_look_back_factor,
  787. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  788. target_length=ys_in_lens,
  789. is_training=self.training,
  790. )
  791. elif self.encoder2.overlap_chunk_cls is not None:
  792. encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  793. chunk_outs=None)
  794. # try:
  795. # 1. Forward decoder
  796. decoder_out, _ = self.decoder2(
  797. encoder_out,
  798. encoder_out_lens,
  799. ys_in_pad,
  800. ys_in_lens,
  801. chunk_mask=scama_mask,
  802. pre_acoustic_embeds=pre_acoustic_embeds,
  803. )
  804. # 2. Compute attention loss
  805. loss_att = self.criterion_att(decoder_out, ys_out_pad)
  806. acc_att = th_accuracy(
  807. decoder_out.view(-1, self.vocab_size),
  808. ys_out_pad,
  809. ignore_label=self.ignore_id,
  810. )
  811. # predictor loss
  812. loss_pre = self.criterion_pre(ys_in_lens.type_as(pre_token_length), pre_token_length)
  813. # Compute cer/wer using attention-decoder
  814. if self.training or self.error_calculator is None:
  815. cer_att, wer_att = None, None
  816. else:
  817. ys_hat = decoder_out.argmax(dim=-1)
  818. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  819. return loss_att, acc_att, cer_att, wer_att, loss_pre
  820. def calc_predictor_mask(
  821. self,
  822. encoder_out: torch.Tensor,
  823. encoder_out_lens: torch.Tensor,
  824. ys_pad: torch.Tensor = None,
  825. ys_pad_lens: torch.Tensor = None,
  826. ):
  827. # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  828. # ys_in_lens = ys_pad_lens + 1
  829. ys_out_pad, ys_in_lens = None, None
  830. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  831. device=encoder_out.device)[:, None, :]
  832. mask_chunk_predictor = None
  833. if self.encoder.overlap_chunk_cls is not None:
  834. mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
  835. device=encoder_out.device,
  836. batch_size=encoder_out.size(
  837. 0))
  838. mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  839. batch_size=encoder_out.size(0))
  840. encoder_out = encoder_out * mask_shfit_chunk
  841. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
  842. ys_out_pad,
  843. encoder_out_mask,
  844. ignore_id=self.ignore_id,
  845. mask_chunk_predictor=mask_chunk_predictor,
  846. target_label_length=ys_in_lens,
  847. )
  848. predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
  849. encoder_out_lens)
  850. scama_mask = None
  851. if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
  852. encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
  853. attention_chunk_center_bias = 0
  854. attention_chunk_size = encoder_chunk_size
  855. decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
  856. mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  857. device=encoder_out.device,
  858. batch_size=encoder_out.size(
  859. 0))
  860. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
  861. predictor_alignments=predictor_alignments,
  862. encoder_sequence_length=encoder_out_lens,
  863. chunk_size=1,
  864. encoder_chunk_size=encoder_chunk_size,
  865. attention_chunk_center_bias=attention_chunk_center_bias,
  866. attention_chunk_size=attention_chunk_size,
  867. attention_chunk_type=self.decoder_attention_chunk_type,
  868. step=None,
  869. predictor_mask_chunk_hopping=mask_chunk_predictor,
  870. decoder_att_look_back_factor=decoder_att_look_back_factor,
  871. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  872. target_length=ys_in_lens,
  873. is_training=self.training,
  874. )
  875. elif self.encoder.overlap_chunk_cls is not None:
  876. encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  877. chunk_outs=None)
  878. return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
  879. def calc_predictor_mask2(
  880. self,
  881. encoder_out: torch.Tensor,
  882. encoder_out_lens: torch.Tensor,
  883. ys_pad: torch.Tensor = None,
  884. ys_pad_lens: torch.Tensor = None,
  885. ):
  886. # ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  887. # ys_in_lens = ys_pad_lens + 1
  888. ys_out_pad, ys_in_lens = None, None
  889. encoder_out_mask = sequence_mask(encoder_out_lens, maxlen=encoder_out.size(1), dtype=encoder_out.dtype,
  890. device=encoder_out.device)[:, None, :]
  891. mask_chunk_predictor = None
  892. if self.encoder2.overlap_chunk_cls is not None:
  893. mask_chunk_predictor = self.encoder2.overlap_chunk_cls.get_mask_chunk_predictor(None,
  894. device=encoder_out.device,
  895. batch_size=encoder_out.size(
  896. 0))
  897. mask_shfit_chunk = self.encoder2.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
  898. batch_size=encoder_out.size(0))
  899. encoder_out = encoder_out * mask_shfit_chunk
  900. pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor2(encoder_out,
  901. ys_out_pad,
  902. encoder_out_mask,
  903. ignore_id=self.ignore_id,
  904. mask_chunk_predictor=mask_chunk_predictor,
  905. target_label_length=ys_in_lens,
  906. )
  907. predictor_alignments, predictor_alignments_len = self.predictor2.gen_frame_alignments(pre_alphas,
  908. encoder_out_lens)
  909. scama_mask = None
  910. if self.encoder2.overlap_chunk_cls is not None and self.decoder_attention_chunk_type2 == 'chunk':
  911. encoder_chunk_size = self.encoder2.overlap_chunk_cls.chunk_size_pad_shift_cur
  912. attention_chunk_center_bias = 0
  913. attention_chunk_size = encoder_chunk_size
  914. decoder_att_look_back_factor = self.encoder2.overlap_chunk_cls.decoder_att_look_back_factor_cur
  915. mask_shift_att_chunk_decoder = self.encoder2.overlap_chunk_cls.get_mask_shift_att_chunk_decoder(None,
  916. device=encoder_out.device,
  917. batch_size=encoder_out.size(
  918. 0))
  919. scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn2(
  920. predictor_alignments=predictor_alignments,
  921. encoder_sequence_length=encoder_out_lens,
  922. chunk_size=1,
  923. encoder_chunk_size=encoder_chunk_size,
  924. attention_chunk_center_bias=attention_chunk_center_bias,
  925. attention_chunk_size=attention_chunk_size,
  926. attention_chunk_type=self.decoder_attention_chunk_type2,
  927. step=None,
  928. predictor_mask_chunk_hopping=mask_chunk_predictor,
  929. decoder_att_look_back_factor=decoder_att_look_back_factor,
  930. mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
  931. target_length=ys_in_lens,
  932. is_training=self.training,
  933. )
  934. elif self.encoder2.overlap_chunk_cls is not None:
  935. encoder_out, encoder_out_lens = self.encoder2.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
  936. chunk_outs=None)
  937. return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
  938. def _calc_ctc_loss(
  939. self,
  940. encoder_out: torch.Tensor,
  941. encoder_out_lens: torch.Tensor,
  942. ys_pad: torch.Tensor,
  943. ys_pad_lens: torch.Tensor,
  944. ):
  945. # Calc CTC loss
  946. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  947. # Calc CER using CTC
  948. cer_ctc = None
  949. if not self.training and self.error_calculator is not None:
  950. ys_hat = self.ctc.argmax(encoder_out).data
  951. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  952. return loss_ctc, cer_ctc
  953. def _calc_ctc_loss2(
  954. self,
  955. encoder_out: torch.Tensor,
  956. encoder_out_lens: torch.Tensor,
  957. ys_pad: torch.Tensor,
  958. ys_pad_lens: torch.Tensor,
  959. ):
  960. # Calc CTC loss
  961. loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  962. # Calc CER using CTC
  963. cer_ctc = None
  964. if not self.training and self.error_calculator is not None:
  965. ys_hat = self.ctc2.argmax(encoder_out).data
  966. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  967. return loss_ctc, cer_ctc