# e2e_asr_paraformer.py
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import random
import numpy as np
from typeguard import check_argument_types

from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.predictor.cif import mae_loss
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.predictor.cif import CifPredictorV3

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


class Paraformer(AbsESPnetModel):
    """
    Author: Speech Lab, Alibaba Group, China

    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = vocab_size - 1 if sos is None else sos
        self.eos = vocab_size - 1 if eos is None else eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.error_calculator = None

        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0

        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1

        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
            )

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
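
    # A minimal, purely illustrative sketch of how the two helpers above can be
    # chained at inference time (assuming `model` is a trained Paraformer and
    # `enc`, `enc_lens` come from `model.encode(...)`); the actual decoding loop
    # lives in the inference scripts:
    #
    #   pre_acoustic_embeds, pre_token_length, _, _ = model.calc_predictor(enc, enc_lens)
    #   token_num = pre_token_length.round().long()
    #   log_probs, _ = model.cal_decoder_with_predictor(enc, enc_lens, pre_acoustic_embeds, token_num)
    #   hyp = log_probs.argmax(-1)  # greedy tokens, one parallel pass, no autoregression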

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood (nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)

        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood (nll) from transformer-decoder

        To avoid OOM, this function separates the input into batches.
        It then calls nll for each batch and combines the results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, number of samples each batch contains when computing nll;
                        you may change this to avoid OOM or to increase GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
                batch_ys_pad = ys_pad[start_idx:end_idx, :]
                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
                batch_nll = self.nll(
                    batch_encoder_out,
                    batch_encoder_out_lens,
                    batch_ys_pad,
                    batch_ys_pad_lens,
                )
                nll.append(batch_nll)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info(
                    "enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)
                )
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
            )
        else:
            if self.step_cur < 2:
                logging.info(
                    "disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)
                )
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        # (a self-contained toy reproduction of this masking rule is sketched at
        # the bottom of this file)
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)

        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    input_mask[li].scatter_(
                        dim=0,
                        # use the target device instead of hard-coding .cuda()
                        index=torch.randperm(seq_lens[li])[:target_num].to(ys_pad.device),
                        value=0,
                    )
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0
        )
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc


class ParaformerBert(Paraformer):
    """
    Author: Speech Lab, Alibaba Group, China

    Paraformer2: an advanced Paraformer with LFMMI and BERT for non-autoregressive end-to-end speech recognition
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        embeds_id: int = 2,
        embeds_loss_weight: float = 0.0,
        embed_dims: int = 768,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )

        self.decoder.embeds_id = embeds_id
        decoder_attention_dim = self.decoder.attention_dim
        self.pro_nn = torch.nn.Linear(decoder_attention_dim, embed_dims)
        self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.embeds_loss_weight = embeds_loss_weight
        self.length_normalized_loss = length_normalized_loss

    def _calc_embed_loss(
        self,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        embed: torch.Tensor = None,
        embed_lengths: torch.Tensor = None,
        embeds_outputs: torch.Tensor = None,
    ):
        embeds_outputs = self.pro_nn(embeds_outputs)
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        embeds_outputs *= tgt_mask  # b x l x d
        embed *= tgt_mask  # b x l x d
        cos_loss = 1.0 - self.cos(embeds_outputs, embed)
        cos_loss *= tgt_mask.squeeze(2)
        if self.length_normalized_loss:
            token_num_total = torch.sum(tgt_mask)
        else:
            token_num_total = tgt_mask.size()[0]
        cos_loss_total = torch.sum(cos_loss)
        cos_loss = cos_loss_total / token_num_total
        # print("cos_loss: {}".format(cos_loss))
        return cos_loss

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info(
                    "enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)
                )
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
            )
        else:
            if self.step_cur < 2:
                logging.info(
                    "disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)
                )
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        embeds_outputs = None
        if len(decoder_outs) > 2:
            embeds_outputs = decoder_outs[2]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        embed: torch.Tensor = None,
        embed_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1

        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max(), :]
        if embed is not None:
            embed = embed[:, : embed_lengths.max(), :]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = 0.0
        cos_loss = 0.0
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_ret = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss_att, acc_att, cer_att, wer_att, loss_pre = (
                loss_ret[0],
                loss_ret[1],
                loss_ret[2],
                loss_ret[3],
                loss_ret[4],
            )
            embeds_outputs = None
            if len(loss_ret) > 5:
                embeds_outputs = loss_ret[5]
            if embeds_outputs is not None:
                cos_loss = self._calc_embed_loss(text, text_lengths, embed, embed_lengths, embeds_outputs)

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
                + cos_loss * self.embeds_loss_weight
            )

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre > 0.0 else None
        stats["cos_loss"] = cos_loss.detach().cpu() if cos_loss > 0.0 else None

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight


class BiCifParaformer(Paraformer):
    """
    Paraformer model with an extra CIF predictor
    to conduct accurate timestamp prediction.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )

        assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CifPredictorV3"

    def _calc_pre2_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        _, _, _, _, pre_token_length2 = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )
        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
        return loss_pre2

    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1

        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        stats = dict()

        # 2. Predictor-only loss: fine-tune the second (timestamp) token-count head
        loss_pre2 = self._calc_pre2_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        loss = loss_pre2

        stats["loss_pre2"] = loss_pre2.detach().cpu()
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
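

if __name__ == "__main__":
    # Self-contained toy sketch (not part of FunASR) of the masking rule used by
    # Paraformer.sampler(): positions where a hypothetical first-pass prediction
    # already matches the target keep their CIF acoustic embedding, while a random
    # fraction (sampling_ratio) of the remaining positions would instead take the
    # ground-truth character embedding. Only the mask computation is reproduced
    # here, on dummy tensors, so it runs without any FunASR components.
    torch.manual_seed(0)
    ignore_id = -1
    sampling_ratio = 0.75
    ys_pad = torch.tensor([[11, 12, 13, 14, ignore_id]])  # padded targets (batch=1, len=5)
    pred_tokens = torch.tensor([[11, 99, 13, 98, 0]])     # dummy first-pass argmax output

    nonpad_positions = ys_pad.ne(ignore_id)
    seq_lens = nonpad_positions.sum(1)
    same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
    input_mask = torch.ones_like(ys_pad)
    for li in range(ys_pad.size(0)):
        # number of wrong positions, scaled by the sampling ratio
        target_num = ((seq_lens[li] - same_num[li]).float() * sampling_ratio).long()
        if target_num > 0:
            # zero out a random subset of positions within the valid length
            input_mask[li].scatter_(0, torch.randperm(int(seq_lens[li]))[:target_num], 0)
    input_mask = input_mask.eq(1) & nonpad_positions
    # True  -> keep the acoustic embedding from the CIF predictor
    # False -> substitute the ground-truth character embedding (or padding)
    print(input_mask)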