e2e_asr_paraformer.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. import random
  11. import numpy as np
  12. from typeguard import check_argument_types
  13. from funasr.layers.abs_normalize import AbsNormalize
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, # noqa: H301
  16. )
  17. from funasr.models.ctc import CTC
  18. from funasr.models.decoder.abs_decoder import AbsDecoder
  19. from funasr.models.e2e_asr_common import ErrorCalculator
  20. from funasr.models.encoder.abs_encoder import AbsEncoder
  21. from funasr.models.frontend.abs_frontend import AbsFrontend
  22. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  23. from funasr.models.predictor.cif import mae_loss
  24. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  25. from funasr.models.specaug.abs_specaug import AbsSpecAug
  26. from funasr.modules.add_sos_eos import add_sos_eos
  27. from funasr.modules.nets_utils import make_pad_mask, pad_list
  28. from funasr.modules.nets_utils import th_accuracy
  29. from funasr.torch_utils.device_funcs import force_gatherable
  30. from funasr.train.abs_espnet_model import AbsESPnetModel
  31. from funasr.models.predictor.cif import CifPredictorV3
  if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
      # torch<1.6.0 has no torch.cuda.amp: provide a stand-in with the same
      # call shape so that ``with autocast(False):`` keeps working.
      @contextmanager
      def autocast(enabled=True):
          """No-op replacement for ``torch.cuda.amp.autocast``; ``enabled`` is ignored."""
          yield
  else:
      from torch.cuda.amp import autocast
  class Paraformer(AbsESPnetModel):
      """
      Author: Speech Lab, Alibaba Group, China
      Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
      https://arxiv.org/abs/2206.08317

      Depending on the configured weights, training combines:
        * a CTC loss on the encoder output (plus optional intermediate CTC),
        * a decoder loss computed from the predictor's acoustic embeddings,
        * a length loss (``criterion_pre``) between the predicted and the
          reference token counts, scaled by ``predictor_weight``.
      """

      def __init__(
          self,
          vocab_size: int,
          token_list: Union[Tuple[str, ...], List[str]],
          frontend: Optional[AbsFrontend],
          specaug: Optional[AbsSpecAug],
          normalize: Optional[AbsNormalize],
          preencoder: Optional[AbsPreEncoder],
          encoder: AbsEncoder,
          postencoder: Optional[AbsPostEncoder],
          decoder: AbsDecoder,
          ctc: CTC,
          ctc_weight: float = 0.5,
          interctc_weight: float = 0.0,
          ignore_id: int = -1,
          blank_id: int = 0,
          sos: int = 1,
          eos: int = 2,
          lsm_weight: float = 0.0,
          length_normalized_loss: bool = False,
          report_cer: bool = True,
          report_wer: bool = True,
          sym_space: str = "<space>",
          sym_blank: str = "<blank>",
          extract_feats_in_collect_stats: bool = True,
          predictor=None,
          predictor_weight: float = 0.0,
          predictor_bias: int = 0,
          sampling_ratio: float = 0.2,
          share_embedding: bool = False,
      ):
          """Assemble sub-modules and loss criteria.

          Args:
              vocab_size: Size of the output vocabulary.
              token_list: Vocabulary tokens; a list or a tuple is accepted and
                  stored as a new list.
              ctc_weight: CTC loss weight in [0, 1]. 0 drops the CTC branch
                  (``self.ctc = None``); 1 drops the decoder
                  (``self.decoder = None``).
              interctc_weight: Intermediate-CTC weight in [0, 1).
              ignore_id: Padding id in target tensors.
              sos, eos: Start/end symbol ids; ``None`` maps to ``vocab_size - 1``.
              predictor: Module returning acoustic embeddings, token counts,
                  alphas and peak indices from the encoder output.
              predictor_weight: Weight of the predictor length loss.
              predictor_bias: When 1, targets are replaced by the second output
                  of ``add_sos_eos`` and their lengths are increased by one.
              sampling_ratio: Scale of the sampler; 0 disables it.
              share_embedding: When True, the sampler reads token embeddings
                  from ``decoder.output_layer.weight`` and ``decoder.embed`` is
                  dropped.
          """
          assert check_argument_types()
          assert 0.0 <= ctc_weight <= 1.0, ctc_weight
          assert 0.0 <= interctc_weight < 1.0, interctc_weight

          super().__init__()
          self.blank_id = blank_id
          # sos/eos fall back to the last vocabulary id when passed as None.
          self.sos = vocab_size - 1 if sos is None else sos
          self.eos = vocab_size - 1 if eos is None else eos
          self.vocab_size = vocab_size
          self.ignore_id = ignore_id
          self.ctc_weight = ctc_weight
          self.interctc_weight = interctc_weight
          # list() accepts both annotated forms; tuples have no .copy().
          self.token_list = list(token_list)

          self.frontend = frontend
          self.specaug = specaug
          self.normalize = normalize
          self.preencoder = preencoder
          self.postencoder = postencoder
          self.encoder = encoder

          if not hasattr(self.encoder, "interctc_use_conditioning"):
              self.encoder.interctc_use_conditioning = False
          if self.encoder.interctc_use_conditioning:
              self.encoder.conditioning_layer = torch.nn.Linear(
                  vocab_size, self.encoder.output_size()
              )

          self.error_calculator = None

          if ctc_weight == 1.0:
              self.decoder = None
          else:
              self.decoder = decoder

          self.criterion_att = LabelSmoothingLoss(
              size=vocab_size,
              padding_idx=ignore_id,
              smoothing=lsm_weight,
              normalize_length=length_normalized_loss,
          )

          if report_cer or report_wer:
              self.error_calculator = ErrorCalculator(
                  token_list, sym_space, sym_blank, report_cer, report_wer
              )

          if ctc_weight == 0.0:
              self.ctc = None
          else:
              self.ctc = ctc

          self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
          self.predictor = predictor
          self.predictor_weight = predictor_weight
          self.predictor_bias = predictor_bias
          self.sampling_ratio = sampling_ratio
          self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
          # Counts forward() calls; used to log the sampler state only once.
          self.step_cur = 0

          self.share_embedding = share_embedding
          # With ctc_weight == 1.0 there is no decoder whose embedding could
          # be dropped; the unguarded access raised AttributeError.
          if self.share_embedding and self.decoder is not None:
              self.decoder.embed = None

      def forward(
          self,
          speech: torch.Tensor,
          speech_lengths: torch.Tensor,
          text: torch.Tensor,
          text_lengths: torch.Tensor,
      ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
          """Frontend + Encoder + Decoder + Calc loss

          Args:
              speech: (Batch, Length, ...)
              speech_lengths: (Batch, )
              text: (Batch, Length)
              text_lengths: (Batch,)

          Returns:
              (loss, stats, weight) after ``force_gatherable``; ``weight`` is
              the batch size.
          """
          assert text_lengths.dim() == 1, text_lengths.shape
          # Check that batch_size is unified
          assert (
              speech.shape[0]
              == speech_lengths.shape[0]
              == text.shape[0]
              == text_lengths.shape[0]
          ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
          batch_size = speech.shape[0]
          self.step_cur += 1
          # for data-parallel: trim padding beyond the longest item
          text = text[:, : text_lengths.max()]
          speech = speech[:, :speech_lengths.max()]

          # 1. Encoder
          encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
          intermediate_outs = None
          if isinstance(encoder_out, tuple):
              intermediate_outs = encoder_out[1]
              encoder_out = encoder_out[0]

          loss_att, acc_att, cer_att, wer_att = None, None, None, None
          loss_ctc, cer_ctc = None, None
          loss_pre = None
          stats = dict()

          # 1. CTC branch
          if self.ctc_weight != 0.0:
              loss_ctc, cer_ctc = self._calc_ctc_loss(
                  encoder_out, encoder_out_lens, text, text_lengths
              )

              # Collect CTC branch stats
              stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
              stats["cer_ctc"] = cer_ctc

          # Intermediate CTC (optional)
          loss_interctc = 0.0
          if self.interctc_weight != 0.0 and intermediate_outs is not None:
              for layer_idx, intermediate_out in intermediate_outs:
                  # we assume intermediate_out has the same length & padding
                  # as those of encoder_out
                  loss_ic, cer_ic = self._calc_ctc_loss(
                      intermediate_out, encoder_out_lens, text, text_lengths
                  )
                  loss_interctc = loss_interctc + loss_ic

                  # Collect Intermediate CTC stats
                  stats["loss_interctc_layer{}".format(layer_idx)] = (
                      loss_ic.detach() if loss_ic is not None else None
                  )
                  stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

              loss_interctc = loss_interctc / len(intermediate_outs)

              # calculate whole encoder loss
              loss_ctc = (
                  1 - self.interctc_weight
              ) * loss_ctc + self.interctc_weight * loss_interctc

          # 2b. Attention decoder branch
          if self.ctc_weight != 1.0:
              loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
                  encoder_out, encoder_out_lens, text, text_lengths
              )

          # 3. CTC-Att loss definition
          if self.ctc_weight == 0.0:
              loss = loss_att + loss_pre * self.predictor_weight
          elif self.ctc_weight == 1.0:
              loss = loss_ctc
          else:
              loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

          # Collect Attn branch stats
          stats["loss_att"] = loss_att.detach() if loss_att is not None else None
          stats["acc"] = acc_att
          stats["cer"] = cer_att
          stats["wer"] = wer_att
          stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

          stats["loss"] = torch.clone(loss.detach())

          # force_gatherable: to-device and to-tensor if scalar for DataParallel
          loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
          return loss, stats, weight

      def collect_feats(
          self,
          speech: torch.Tensor,
          speech_lengths: torch.Tensor,
          text: torch.Tensor,
          text_lengths: torch.Tensor,
      ) -> Dict[str, torch.Tensor]:
          """Return frontend features for statistics collection.

          When ``extract_feats_in_collect_stats`` is False the raw inputs are
          returned unchanged (with a warning). ``text``/``text_lengths`` are
          unused.
          """
          if self.extract_feats_in_collect_stats:
              feats, feats_lengths = self._extract_feats(speech, speech_lengths)
          else:
              # Generate dummy stats if extract_feats_in_collect_stats is False
              logging.warning(
                  "Generating dummy stats for feats and feats_lengths, "
                  "because encoder_conf.extract_feats_in_collect_stats is "
                  f"{self.extract_feats_in_collect_stats}"
              )
              feats, feats_lengths = speech, speech_lengths
          return {"feats": feats, "feats_lengths": feats_lengths}

      def encode(
          self, speech: torch.Tensor, speech_lengths: torch.Tensor
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Frontend + Encoder. Note that this method is used by asr_inference.py

          Args:
              speech: (Batch, Length, ...)
              speech_lengths: (Batch, )

          Returns:
              (encoder_out, encoder_out_lens); ``encoder_out`` is the tuple
              ``(encoder_out, intermediate_outs)`` when the encoder produced
              intermediate outputs.
          """
          # Feature extraction and normalization run with autocast disabled.
          with autocast(False):
              # 1. Extract feats
              feats, feats_lengths = self._extract_feats(speech, speech_lengths)

              # 2. Data augmentation
              if self.specaug is not None and self.training:
                  feats, feats_lengths = self.specaug(feats, feats_lengths)

              # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
              if self.normalize is not None:
                  feats, feats_lengths = self.normalize(feats, feats_lengths)

          # Pre-encoder, e.g. used for raw input data
          if self.preencoder is not None:
              feats, feats_lengths = self.preencoder(feats, feats_lengths)

          # 4. Forward encoder
          # feats: (Batch, Length, Dim)
          # -> encoder_out: (Batch, Length2, Dim2)
          if self.encoder.interctc_use_conditioning:
              encoder_out, encoder_out_lens, _ = self.encoder(
                  feats, feats_lengths, ctc=self.ctc
              )
          else:
              encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
          intermediate_outs = None
          if isinstance(encoder_out, tuple):
              intermediate_outs = encoder_out[1]
              encoder_out = encoder_out[0]

          # Post-encoder, e.g. NLU
          if self.postencoder is not None:
              encoder_out, encoder_out_lens = self.postencoder(
                  encoder_out, encoder_out_lens
              )

          assert encoder_out.size(0) == speech.size(0), (
              encoder_out.size(),
              speech.size(0),
          )
          assert encoder_out.size(1) <= encoder_out_lens.max(), (
              encoder_out.size(),
              encoder_out_lens.max(),
          )

          if intermediate_outs is not None:
              return (encoder_out, intermediate_outs), encoder_out_lens

          return encoder_out, encoder_out_lens

      def encode_chunk(
          self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Streaming variant of ``encode`` using ``encoder.forward_chunk``.

          Args:
              speech: (Batch, Length, ...)
              speech_lengths: (Batch, )
              cache: Dict holding the streaming state; its ``"encoder"`` entry
                  is passed to ``encoder.forward_chunk``. Required despite the
                  ``None`` default.
          """
          with autocast(False):
              # 1. Extract feats
              feats, feats_lengths = self._extract_feats(speech, speech_lengths)

              # 2. Data augmentation
              if self.specaug is not None and self.training:
                  feats, feats_lengths = self.specaug(feats, feats_lengths)

              # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
              if self.normalize is not None:
                  feats, feats_lengths = self.normalize(feats, feats_lengths)

          # Pre-encoder, e.g. used for raw input data
          if self.preencoder is not None:
              feats, feats_lengths = self.preencoder(feats, feats_lengths)

          # 4. Forward encoder
          # feats: (Batch, Length, Dim)
          # -> encoder_out: (Batch, Length2, Dim2)
          if self.encoder.interctc_use_conditioning:
              encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
                  feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
              )
          else:
              encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
          intermediate_outs = None
          if isinstance(encoder_out, tuple):
              intermediate_outs = encoder_out[1]
              encoder_out = encoder_out[0]

          # Post-encoder, e.g. NLU
          if self.postencoder is not None:
              encoder_out, encoder_out_lens = self.postencoder(
                  encoder_out, encoder_out_lens
              )

          assert encoder_out.size(0) == speech.size(0), (
              encoder_out.size(),
              speech.size(0),
          )
          assert encoder_out.size(1) <= encoder_out_lens.max(), (
              encoder_out.size(),
              encoder_out_lens.max(),
          )

          if intermediate_outs is not None:
              return (encoder_out, intermediate_outs), encoder_out_lens

          return encoder_out, encoder_out_lens

      def calc_predictor(self, encoder_out, encoder_out_lens):
          """Run the predictor without targets (inference).

          Returns the four predictor outputs: acoustic embeddings, token
          length, alphas and peak index.
          """
          encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
              encoder_out.device)
          pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
                                                                                         ignore_id=self.ignore_id)
          return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

      def calc_predictor_chunk(self, encoder_out, cache=None):
          """Streaming predictor call; ``cache["encoder"]`` is required."""
          pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
          return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

      def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
          """Decode from predictor embeddings.

          Returns the log-softmax of the decoder's first output, and
          ``ys_pad_lens`` unchanged.
          """
          decoder_outs = self.decoder(
              encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
          )
          decoder_out = decoder_outs[0]
          decoder_out = torch.log_softmax(decoder_out, dim=-1)
          return decoder_out, ys_pad_lens

      def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
          """Streaming decode; ``cache["decoder"]`` is required. Returns log-probs."""
          decoder_out = self.decoder.forward_chunk(
              encoder_out, sematic_embeds, cache["decoder"]
          )
          decoder_out = torch.log_softmax(decoder_out, dim=-1)
          return decoder_out

      def _extract_feats(
          self, speech: torch.Tensor, speech_lengths: torch.Tensor
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Apply the frontend if present; otherwise return the input as-is."""
          assert speech_lengths.dim() == 1, speech_lengths.shape

          # for data-parallel
          speech = speech[:, : speech_lengths.max()]

          if self.frontend is not None:
              # Frontend
              #  e.g. STFT and Feature extract
              #       data_loader may send time-domain signal in this case
              # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
              feats, feats_lengths = self.frontend(speech, speech_lengths)
          else:
              # No frontend and no feature extract
              feats, feats_lengths = speech, speech_lengths
          return feats, feats_lengths

      def nll(
          self,
          encoder_out: torch.Tensor,
          encoder_out_lens: torch.Tensor,
          ys_pad: torch.Tensor,
          ys_pad_lens: torch.Tensor,
      ) -> torch.Tensor:
          """Compute negative log likelihood(nll) from transformer-decoder

          Normally, this function is called in batchify_nll.

          Args:
              encoder_out: (Batch, Length, Dim)
              encoder_out_lens: (Batch,)
              ys_pad: (Batch, Length)
              ys_pad_lens: (Batch,)

          Returns:
              Per-utterance summed cross entropy, shape (Batch,).
          """
          ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
          ys_in_lens = ys_pad_lens + 1

          # 1. Forward decoder
          decoder_out, _ = self.decoder(
              encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
          )  # [batch, seqlen, dim]
          batch_size = decoder_out.size(0)
          decoder_num_class = decoder_out.size(2)
          # nll: negative log-likelihood
          nll = torch.nn.functional.cross_entropy(
              decoder_out.view(-1, decoder_num_class),
              ys_out_pad.view(-1),
              ignore_index=self.ignore_id,
              reduction="none",
          )
          nll = nll.view(batch_size, -1)
          nll = nll.sum(dim=1)
          assert nll.size(0) == batch_size
          return nll

      def batchify_nll(
          self,
          encoder_out: torch.Tensor,
          encoder_out_lens: torch.Tensor,
          ys_pad: torch.Tensor,
          ys_pad_lens: torch.Tensor,
          batch_size: int = 100,
      ):
          """Compute negative log likelihood(nll) from transformer-decoder

          To avoid OOM, this function separates the input into batches,
          calls nll for each batch and concatenates the results.

          Args:
              encoder_out: (Batch, Length, Dim)
              encoder_out_lens: (Batch,)
              ys_pad: (Batch, Length)
              ys_pad_lens: (Batch,)
              batch_size: int, samples each batch contain when computing nll,
                          you may change this to avoid OOM or increase
                          GPU memory usage

          Raises:
              ValueError: if chunking is needed and ``batch_size`` is not
                  positive (previously this looped forever).
          """
          total_num = encoder_out.size(0)
          if total_num <= batch_size:
              nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
          else:
              if batch_size <= 0:
                  raise ValueError(f"batch_size must be positive, got {batch_size}")
              nll_chunks = []
              for start_idx in range(0, total_num, batch_size):
                  end_idx = min(start_idx + batch_size, total_num)
                  nll_chunks.append(
                      self.nll(
                          encoder_out[start_idx:end_idx, :, :],
                          encoder_out_lens[start_idx:end_idx],
                          ys_pad[start_idx:end_idx, :],
                          ys_pad_lens[start_idx:end_idx],
                      )
                  )
              nll = torch.cat(nll_chunks)
          assert nll.size(0) == total_num
          return nll

      def _calc_att_loss(
          self,
          encoder_out: torch.Tensor,
          encoder_out_lens: torch.Tensor,
          ys_pad: torch.Tensor,
          ys_pad_lens: torch.Tensor,
      ):
          """Decoder loss, accuracy, CER/WER and predictor length loss.

          Returns:
              (loss_att, acc_att, cer_att, wer_att, loss_pre); cer/wer are
              None in training mode or without an error calculator. Accuracy
              and CER/WER use the sampler's first-pass decoder output when the
              sampler is active.
          """
          encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
              encoder_out.device)
          if self.predictor_bias == 1:
              _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
              ys_pad_lens = ys_pad_lens + self.predictor_bias
          pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                                    ignore_id=self.ignore_id)

          # 0. sampler
          decoder_out_1st = None
          if self.sampling_ratio > 0.0:
              if self.step_cur < 2:
                  logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
              sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                             pre_acoustic_embeds)
          else:
              if self.step_cur < 2:
                  logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
              sematic_embeds = pre_acoustic_embeds

          # 1. Forward decoder
          decoder_outs = self.decoder(
              encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
          )
          decoder_out, _ = decoder_outs[0], decoder_outs[1]

          if decoder_out_1st is None:
              decoder_out_1st = decoder_out
          # 2. Compute attention loss
          loss_att = self.criterion_att(decoder_out, ys_pad)
          acc_att = th_accuracy(
              decoder_out_1st.view(-1, self.vocab_size),
              ys_pad,
              ignore_label=self.ignore_id,
          )
          loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

          # Compute cer/wer using attention-decoder
          if self.training or self.error_calculator is None:
              cer_att, wer_att = None, None
          else:
              ys_hat = decoder_out_1st.argmax(dim=-1)
              cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

          return loss_att, acc_att, cer_att, wer_att, loss_pre

      def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
          """Mix predictor embeddings with reference-token embeddings.

          A gradient-free decoder pass on ``pre_acoustic_embeds`` counts, per
          utterance, how many reference tokens it predicts correctly. Then
          ``int((num_tokens - num_correct) * sampling_ratio)`` random positions
          take the embedding of the reference token, and every other valid
          position keeps the predictor embedding.

          Returns:
              (semantic_embeds, first_pass_decoder_out), both zeroed beyond
              ``ys_pad_lens``.
          """
          tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
          # Zero out padded ids (e.g. ignore_id=-1) before the embedding lookup.
          ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
          if self.share_embedding:
              ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
          else:
              ys_pad_embed = self.decoder.embed(ys_pad_masked)
          with torch.no_grad():
              decoder_outs = self.decoder(
                  encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
              )
              decoder_out, _ = decoder_outs[0], decoder_outs[1]
              pred_tokens = decoder_out.argmax(-1)
              nonpad_positions = ys_pad.ne(self.ignore_id)
              seq_lens = nonpad_positions.sum(1)
              same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
              input_mask = torch.ones_like(nonpad_positions)
              bsz = ys_pad.size(0)
              for li in range(bsz):
                  target_num = int(((seq_lens[li] - same_num[li]).float() * self.sampling_ratio).long())
                  if target_num > 0:
                      # Positions are still drawn from the CPU generator, then
                      # moved to the mask's device; the former hard-coded
                      # .cuda() broke CPU-only training.
                      positions = torch.randperm(int(seq_lens[li]))[:target_num]
                      input_mask[li].scatter_(dim=0, index=positions.to(input_mask.device), value=0)
              input_mask = input_mask.eq(1)
              input_mask = input_mask.masked_fill(~nonpad_positions, False)
              input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

          # True in input_mask -> keep predictor embedding; False -> reference embedding.
          sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
              input_mask_expand_dim, 0)
          return sematic_embeds * tgt_mask, decoder_out * tgt_mask

      def _calc_ctc_loss(
          self,
          encoder_out: torch.Tensor,
          encoder_out_lens: torch.Tensor,
          ys_pad: torch.Tensor,
          ys_pad_lens: torch.Tensor,
      ):
          """CTC loss, plus CTC-greedy CER when evaluating with an error calculator."""
          # Calc CTC loss
          loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

          # Calc CER using CTC
          cer_ctc = None
          if not self.training and self.error_calculator is not None:
              ys_hat = self.ctc.argmax(encoder_out).data
              cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
          return loss_ctc, cer_ctc
  541. class ParaformerBert(Paraformer):
  542. """
  543. Author: Speech Lab, Alibaba Group, China
  544. Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
  545. """
  def __init__(
      self,
      vocab_size: int,
      token_list: Union[Tuple[str, ...], List[str]],
      frontend: Optional[AbsFrontend],
      specaug: Optional[AbsSpecAug],
      normalize: Optional[AbsNormalize],
      preencoder: Optional[AbsPreEncoder],
      encoder: AbsEncoder,
      postencoder: Optional[AbsPostEncoder],
      decoder: AbsDecoder,
      ctc: CTC,
      ctc_weight: float = 0.5,
      interctc_weight: float = 0.0,
      ignore_id: int = -1,
      blank_id: int = 0,
      sos: int = 1,
      eos: int = 2,
      lsm_weight: float = 0.0,
      length_normalized_loss: bool = False,
      report_cer: bool = True,
      report_wer: bool = True,
      sym_space: str = "<space>",
      sym_blank: str = "<blank>",
      extract_feats_in_collect_stats: bool = True,
      predictor=None,
      predictor_weight: float = 0.0,
      predictor_bias: int = 0,
      sampling_ratio: float = 0.2,
      embeds_id: int = 2,
      embeds_loss_weight: float = 0.0,
      embed_dims: int = 768,
  ):
      """Build a Paraformer with an extra projection head for the embedding loss.

      Every argument except the last three is forwarded unchanged to the
      base-class constructor (``share_embedding`` keeps its default).

      Args:
          embeds_id: Stored on the decoder as ``decoder.embeds_id``.
          embeds_loss_weight: Stored as ``self.embeds_loss_weight``.
          embed_dims: Output width of the ``pro_nn`` projection.
      """
      assert check_argument_types()
      assert 0.0 <= ctc_weight <= 1.0, ctc_weight
      assert 0.0 <= interctc_weight < 1.0, interctc_weight

      base_kwargs = dict(
          vocab_size=vocab_size,
          token_list=token_list,
          frontend=frontend,
          specaug=specaug,
          normalize=normalize,
          preencoder=preencoder,
          encoder=encoder,
          postencoder=postencoder,
          decoder=decoder,
          ctc=ctc,
          ctc_weight=ctc_weight,
          interctc_weight=interctc_weight,
          ignore_id=ignore_id,
          blank_id=blank_id,
          sos=sos,
          eos=eos,
          lsm_weight=lsm_weight,
          length_normalized_loss=length_normalized_loss,
          report_cer=report_cer,
          report_wer=report_wer,
          sym_space=sym_space,
          sym_blank=sym_blank,
          extract_feats_in_collect_stats=extract_feats_in_collect_stats,
          predictor=predictor,
          predictor_weight=predictor_weight,
          predictor_bias=predictor_bias,
          sampling_ratio=sampling_ratio,
      )
      super().__init__(**base_kwargs)

      dec = self.decoder
      dec.embeds_id = embeds_id
      # Maps decoder features (attention_dim wide) to embed_dims, the width
      # compared against the reference embeddings in _calc_embed_loss.
      self.pro_nn = torch.nn.Linear(dec.attention_dim, embed_dims)
      self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
      self.embeds_loss_weight = embeds_loss_weight
      self.length_normalized_loss = length_normalized_loss
  def _calc_embed_loss(
      self,
      ys_pad: torch.Tensor,
      ys_pad_lens: torch.Tensor,
      embed: torch.Tensor = None,
      embed_lengths: torch.Tensor = None,
      embeds_outputs: torch.Tensor = None,
  ):
      """Cosine-distance loss between projected decoder features and ``embed``.

      ``embeds_outputs`` is projected with ``self.pro_nn`` and compared with
      ``embed`` position by position through ``self.cos``. Positions beyond
      ``ys_pad_lens`` are masked out.

      Args:
          ys_pad: Padded target ids; only its device is used here.
          ys_pad_lens: Valid length of each target sequence.
          embed: Reference embeddings to match. No longer modified in place.
          embed_lengths: Unused; kept for interface compatibility.
          embeds_outputs: Decoder-side features fed to ``self.pro_nn``.

      Returns:
          Sum of ``1 - cos`` over valid positions. It is divided by the number
          of valid positions when ``self.length_normalized_loss`` is set, and
          by the batch size otherwise.
      """
      projected = self.pro_nn(embeds_outputs)
      tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
      # Out-of-place masking: the former ``embed *= tgt_mask`` silently
      # mutated the caller's tensor (and fails on a leaf requiring grad).
      projected = projected * tgt_mask
      embed = embed * tgt_mask
      cos_loss = (1.0 - self.cos(projected, embed)) * tgt_mask.squeeze(2)
      if self.length_normalized_loss:
          token_num_total = torch.sum(tgt_mask)
      else:
          token_num_total = tgt_mask.size(0)
      return torch.sum(cos_loss) / token_num_total
  def _calc_att_loss(
      self,
      encoder_out: torch.Tensor,
      encoder_out_lens: torch.Tensor,
      ys_pad: torch.Tensor,
      ys_pad_lens: torch.Tensor,
  ):
      """Decoder loss and metrics, plus the decoder's optional third output.

      This mirrors the base-class computation. It additionally returns
      ``decoder_outs[2]`` when the decoder yields more than two values.

      Returns:
          (loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs).
          cer/wer are None in training mode or without an error calculator.
          embeds_outputs is None when the decoder returns only two values.
      """
      valid = ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
      encoder_out_mask = valid.to(encoder_out.device)

      if self.predictor_bias == 1:
          _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
          ys_pad_lens = ys_pad_lens + self.predictor_bias

      pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(
          encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
      )

      # 0. Optionally swap part of the predictor embeddings (sampler).
      sampling_on = self.sampling_ratio > 0.0
      announce = self.step_cur < 2
      if not sampling_on:
          if announce:
              logging.info(
                  "disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
          sematic_embeds, decoder_out_1st = pre_acoustic_embeds, None
      else:
          if announce:
              logging.info(
                  "enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
          sematic_embeds, decoder_out_1st = self.sampler(
              encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
          )

      # 1. Decoder pass on the (possibly mixed) embeddings.
      decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
      decoder_out, _ = decoder_outs[0], decoder_outs[1]
      embeds_outputs = decoder_outs[2] if len(decoder_outs) > 2 else None
      if decoder_out_1st is None:
          decoder_out_1st = decoder_out

      # 2. Losses and metrics.
      loss_att = self.criterion_att(decoder_out, ys_pad)
      acc_att = th_accuracy(
          decoder_out_1st.view(-1, self.vocab_size),
          ys_pad,
          ignore_label=self.ignore_id,
      )
      loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

      cer_att = wer_att = None
      if not self.training and self.error_calculator is not None:
          hyp_tokens = decoder_out_1st.argmax(dim=-1)
          cer_att, wer_att = self.error_calculator(hyp_tokens.cpu(), ys_pad.cpu())

      return loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs
  def forward(
          self,
          speech: torch.Tensor,
          speech_lengths: torch.Tensor,
          text: torch.Tensor,
          text_lengths: torch.Tensor,
          embed: torch.Tensor = None,
          embed_lengths: torch.Tensor = None,
  ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
      """Frontend + Encoder + Decoder + Calc loss

      Args:
          speech: (Batch, Length, ...)
          speech_lengths: (Batch, )
          text: (Batch, Length)
          text_lengths: (Batch,)
          embed: optional target embeddings for the auxiliary cosine loss.
              The loss is skipped when this is None.
          embed_lengths: (Batch,) valid lengths of ``embed``; used to trim its axis 1.

      Returns:
          ``(loss, stats, weight)`` as produced by ``force_gatherable``.
      """
      assert text_lengths.dim() == 1, text_lengths.shape
      # Check that batch_size is unified
      assert (
              speech.shape[0]
              == speech_lengths.shape[0]
              == text.shape[0]
              == text_lengths.shape[0]
      ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
      batch_size = speech.shape[0]
      self.step_cur += 1
      # for data-parallel
      text = text[:, : text_lengths.max()]
      # Trim only axis 1 so input of any rank >= 2 works. The previous
      # `[:, :T, :]` raised IndexError for 2-D input.
      speech = speech[:, :speech_lengths.max()]
      if embed is not None and embed_lengths is not None:
          embed = embed[:, :embed_lengths.max(), :]
      # 1. Encoder
      encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
      intermediate_outs = None
      if isinstance(encoder_out, tuple):
          intermediate_outs = encoder_out[1]
          encoder_out = encoder_out[0]
      loss_att, acc_att, cer_att, wer_att = None, None, None, None
      loss_ctc, cer_ctc = None, None
      loss_pre = 0.0
      cos_loss = 0.0
      stats = dict()
      # 1. CTC branch
      if self.ctc_weight != 0.0:
          loss_ctc, cer_ctc = self._calc_ctc_loss(
              encoder_out, encoder_out_lens, text, text_lengths
          )
      # Collect CTC branch stats
      stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
      stats["cer_ctc"] = cer_ctc
      # Intermediate CTC (optional)
      loss_interctc = 0.0
      if self.interctc_weight != 0.0 and intermediate_outs is not None:
          for layer_idx, intermediate_out in intermediate_outs:
              # we assume intermediate_out has the same length & padding
              # as those of encoder_out
              loss_ic, cer_ic = self._calc_ctc_loss(
                  intermediate_out, encoder_out_lens, text, text_lengths
              )
              loss_interctc = loss_interctc + loss_ic
              # Collect intermediate CTC stats
              stats["loss_interctc_layer{}".format(layer_idx)] = (
                  loss_ic.detach() if loss_ic is not None else None
              )
              stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
          loss_interctc = loss_interctc / len(intermediate_outs)
          # calculate whole encoder loss
          loss_ctc = (
                             1 - self.interctc_weight
                     ) * loss_ctc + self.interctc_weight * loss_interctc
      # 2b. Attention decoder branch
      if self.ctc_weight != 1.0:
          loss_ret = self._calc_att_loss(
              encoder_out, encoder_out_lens, text, text_lengths
          )
          loss_att, acc_att, cer_att, wer_att, loss_pre = loss_ret[0], loss_ret[1], loss_ret[2], loss_ret[3], \
                                                          loss_ret[4]
          embeds_outputs = None
          if len(loss_ret) > 5:
              embeds_outputs = loss_ret[5]
          # The auxiliary embedding loss needs both decoder embeddings and targets.
          if embeds_outputs is not None and embed is not None:
              cos_loss = self._calc_embed_loss(text, text_lengths, embed, embed_lengths, embeds_outputs)
      # 3. CTC-Att loss definition
      if self.ctc_weight == 0.0:
          loss = loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
      elif self.ctc_weight == 1.0:
          loss = loss_ctc
      else:
          loss = self.ctc_weight * loss_ctc + (
                  1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
      # Collect Attn branch stats
      stats["loss_att"] = loss_att.detach() if loss_att is not None else None
      stats["acc"] = acc_att
      stats["cer"] = cer_att
      stats["wer"] = wer_att
      stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre > 0.0 else None
      stats["cos_loss"] = cos_loss.detach().cpu() if cos_loss > 0.0 else None
      stats["loss"] = torch.clone(loss.detach())
      # force_gatherable: to-device and to-tensor if scalar for DataParallel
      loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
      return loss, stats, weight
  class BiCifParaformer(Paraformer):
      """
      Paraformer model with an extra cif predictor
      to conduct accurate timestamp prediction
      """

      def __init__(
              self,
              vocab_size: int,
              token_list: Union[Tuple[str, ...], List[str]],
              frontend: Optional[AbsFrontend],
              specaug: Optional[AbsSpecAug],
              normalize: Optional[AbsNormalize],
              preencoder: Optional[AbsPreEncoder],
              encoder: AbsEncoder,
              postencoder: Optional[AbsPostEncoder],
              decoder: AbsDecoder,
              ctc: CTC,
              ctc_weight: float = 0.5,
              interctc_weight: float = 0.0,
              ignore_id: int = -1,
              blank_id: int = 0,
              sos: int = 1,
              eos: int = 2,
              lsm_weight: float = 0.0,
              length_normalized_loss: bool = False,
              report_cer: bool = True,
              report_wer: bool = True,
              sym_space: str = "<space>",
              sym_blank: str = "<blank>",
              extract_feats_in_collect_stats: bool = True,
              predictor=None,
              predictor_weight: float = 0.0,
              predictor_bias: int = 0,
              sampling_ratio: float = 0.2,
      ):
          """Forward every argument unchanged to ``Paraformer.__init__``.

          ``predictor`` must be a ``CifPredictorV3``. The methods below unpack
          five predictor outputs and call ``predictor.get_upsample_timestamp``.
          """
          assert check_argument_types()
          assert 0.0 <= ctc_weight <= 1.0, ctc_weight
          assert 0.0 <= interctc_weight < 1.0, interctc_weight
          super().__init__(
              vocab_size=vocab_size,
              token_list=token_list,
              frontend=frontend,
              specaug=specaug,
              normalize=normalize,
              preencoder=preencoder,
              encoder=encoder,
              postencoder=postencoder,
              decoder=decoder,
              ctc=ctc,
              ctc_weight=ctc_weight,
              interctc_weight=interctc_weight,
              ignore_id=ignore_id,
              blank_id=blank_id,
              sos=sos,
              eos=eos,
              lsm_weight=lsm_weight,
              length_normalized_loss=length_normalized_loss,
              report_cer=report_cer,
              report_wer=report_wer,
              sym_space=sym_space,
              sym_blank=sym_blank,
              extract_feats_in_collect_stats=extract_feats_in_collect_stats,
              predictor=predictor,
              predictor_weight=predictor_weight,
              predictor_bias=predictor_bias,
              sampling_ratio=sampling_ratio,
          )
          assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"

      def _calc_pre2_loss(
              self,
              encoder_out: torch.Tensor,
              encoder_out_lens: torch.Tensor,
              ys_pad: torch.Tensor,
              ys_pad_lens: torch.Tensor,
      ):
          """Compute ``criterion_pre`` between target lengths and the predictor's 5th output.

          When ``predictor_bias == 1``, sos/eos are added to the targets and each
          target length is increased by the bias before comparison.
          """
          encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
              encoder_out.device)
          if self.predictor_bias == 1:
              _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
              ys_pad_lens = ys_pad_lens + self.predictor_bias
          _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                         ignore_id=self.ignore_id)
          loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
          return loss_pre2

      def calc_predictor(self, encoder_out, encoder_out_lens):
          """Run the predictor without targets.

          Returns:
              The first four predictor outputs ``(pre_acoustic_embeds,
              pre_token_length, alphas, pre_peak_index)``. The fifth output is dropped.
          """
          encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
              encoder_out.device)
          pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, _ = self.predictor(
              encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
          return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

      def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
          """Delegate to ``self.predictor.get_upsample_timestamp`` using a frame mask built from the lengths.

          Returns:
              The four values returned by ``get_upsample_timestamp``
              ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)``.
          """
          encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
              encoder_out.device)
          ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                              encoder_out_mask,
                                                                                              token_num)
          return ds_alphas, ds_cif_peak, us_alphas, us_peaks

      def forward(
              self,
              speech: torch.Tensor,
              speech_lengths: torch.Tensor,
              text: torch.Tensor,
              text_lengths: torch.Tensor,
      ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
          """Frontend + Encoder + predictor; the training loss is ``loss_pre2`` only.

          Args:
              speech: (Batch, Length, ...)
              speech_lengths: (Batch, )
              text: (Batch, Length)
              text_lengths: (Batch,)

          Returns:
              ``(loss, stats, weight)`` as produced by ``force_gatherable``.
          """
          assert text_lengths.dim() == 1, text_lengths.shape
          # Check that batch_size is unified
          assert (
                  speech.shape[0]
                  == speech_lengths.shape[0]
                  == text.shape[0]
                  == text_lengths.shape[0]
          ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
          batch_size = speech.shape[0]
          self.step_cur += 1
          # for data-parallel
          text = text[:, : text_lengths.max()]
          speech = speech[:, :speech_lengths.max()]
          # 1. Encoder
          encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
          # Encoders with intermediate outputs return (final_out, intermediate_outs);
          # only the final output is used here, so unwrap it before the predictor.
          if isinstance(encoder_out, tuple):
              encoder_out = encoder_out[0]
          stats = dict()
          loss_pre2 = self._calc_pre2_loss(
              encoder_out, encoder_out_lens, text, text_lengths
          )
          loss = loss_pre2
          stats["loss_pre2"] = loss_pre2.detach().cpu()
          stats["loss"] = torch.clone(loss.detach())
          # force_gatherable: to-device and to-tensor if scalar for DataParallel
          loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
          return loss, stats, weight
  927. class ContextualParaformer(Paraformer):
  928. """
  929. Paraformer model with contextual hotword
  930. """
  def __init__(
          self,
          vocab_size: int,
          token_list: Union[Tuple[str, ...], List[str]],
          frontend: Optional[AbsFrontend],
          specaug: Optional[AbsSpecAug],
          normalize: Optional[AbsNormalize],
          preencoder: Optional[AbsPreEncoder],
          encoder: AbsEncoder,
          postencoder: Optional[AbsPostEncoder],
          decoder: AbsDecoder,
          ctc: CTC,
          ctc_weight: float = 0.5,
          interctc_weight: float = 0.0,
          ignore_id: int = -1,
          blank_id: int = 0,
          sos: int = 1,
          eos: int = 2,
          lsm_weight: float = 0.0,
          length_normalized_loss: bool = False,
          report_cer: bool = True,
          report_wer: bool = True,
          sym_space: str = "<space>",
          sym_blank: str = "<blank>",
          extract_feats_in_collect_stats: bool = True,
          predictor=None,
          predictor_weight: float = 0.0,
          predictor_bias: int = 0,
          sampling_ratio: float = 0.2,
          min_hw_length: int = 2,
          max_hw_length: int = 4,
          sample_rate: float = 0.6,
          batch_rate: float = 0.5,
          double_rate: float = -1.0,
          target_buffer_length: int = -1,
          inner_dim: int = 256,
          bias_encoder_type: str = 'lstm',
          label_bracket: bool = False,
  ):
      """Build a Paraformer with a hotword (contextual) bias encoder.

      Arguments up to ``sampling_ratio`` are forwarded to ``Paraformer.__init__``.

      Args:
          min_hw_length / max_hw_length: bounds (in tokens) used when sampling
              hotword spans from the targets.
          sample_rate: probability of sampling one hotword from a target sequence.
          batch_rate: stored as ``self.batch_rate``; not read by the methods shown here.
          double_rate: probability of sampling two hotwords from a long-enough
              target. A negative value disables this, since ``random.random()``
              is never negative.
          target_buffer_length: when > 0, sampled hotword encodings are kept in
              a buffer across steps and trimmed once it grows past this size.
          inner_dim: size of the bias embedding and of the LSTM input/hidden state.
          bias_encoder_type: only ``'lstm'`` is supported.
          label_bracket: accepted for config compatibility; unused in this constructor.

      Raises:
          ValueError: if ``bias_encoder_type`` is not ``'lstm'``. Previously this
              was only logged, leaving ``bias_encoder`` undefined until the first
              forward pass crashed.
      """
      assert check_argument_types()
      assert 0.0 <= ctc_weight <= 1.0, ctc_weight
      assert 0.0 <= interctc_weight < 1.0, interctc_weight
      if bias_encoder_type != 'lstm':
          raise ValueError("Unsupported bias encoder type: {}".format(bias_encoder_type))
      super().__init__(
          vocab_size=vocab_size,
          token_list=token_list,
          frontend=frontend,
          specaug=specaug,
          normalize=normalize,
          preencoder=preencoder,
          encoder=encoder,
          postencoder=postencoder,
          decoder=decoder,
          ctc=ctc,
          ctc_weight=ctc_weight,
          interctc_weight=interctc_weight,
          ignore_id=ignore_id,
          blank_id=blank_id,
          sos=sos,
          eos=eos,
          lsm_weight=lsm_weight,
          length_normalized_loss=length_normalized_loss,
          report_cer=report_cer,
          report_wer=report_wer,
          sym_space=sym_space,
          sym_blank=sym_blank,
          extract_feats_in_collect_stats=extract_feats_in_collect_stats,
          predictor=predictor,
          predictor_weight=predictor_weight,
          predictor_bias=predictor_bias,
          sampling_ratio=sampling_ratio,
      )
      logging.warning("enable bias encoder sampling and contextual training")
      # Single-layer, batch-first LSTM that encodes each hotword token sequence.
      self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)
      self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
      self.min_hw_length = min_hw_length
      self.max_hw_length = max_hw_length
      self.sample_rate = sample_rate
      self.batch_rate = batch_rate
      self.target_buffer_length = target_buffer_length
      self.double_rate = double_rate
      if self.target_buffer_length > 0:
          # Cross-step hotword buffer state; only used when buffering is enabled.
          self.hotword_buffer = None
          self.length_record = []
          self.current_buffer_length = 0
  1013. self.double_rate = double_rate
  1014. if self.target_buffer_length > 0:
  1015. self.hotword_buffer = None
  1016. self.length_record = []
  1017. self.current_buffer_length = 0
  def forward(
          self,
          speech: torch.Tensor,
          speech_lengths: torch.Tensor,
          text: torch.Tensor,
          text_lengths: torch.Tensor,
  ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
      """Frontend + Encoder + Decoder + Calc loss, with hotword-biased decoding.

      The attention branch (``self._calc_att_loss``) samples hotwords from the
      targets and feeds them to the decoder as contextual information.

      Args:
          speech: (Batch, Length, ...)
          speech_lengths: (Batch, )
          text: (Batch, Length)
          text_lengths: (Batch,)

      Returns:
          ``(loss, stats, weight)`` as produced by ``force_gatherable``.
      """
      assert text_lengths.dim() == 1, text_lengths.shape
      # Every input must share the same leading (batch) dimension.
      assert (
              speech.shape[0]
              == speech_lengths.shape[0]
              == text.shape[0]
              == text_lengths.shape[0]
      ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
      batch_size = speech.shape[0]
      self.step_cur += 1

      # Drop padding beyond the longest item (needed for data-parallel splits).
      text = text[:, : text_lengths.max()]
      speech = speech[:, :speech_lengths.max()]

      # 1. Encoder; a tuple result also carries intermediate-layer outputs.
      encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
      intermediate_outs = None
      if isinstance(encoder_out, tuple):
          encoder_out, intermediate_outs = encoder_out[0], encoder_out[1]

      loss_att = acc_att = cer_att = wer_att = None
      loss_ctc = cer_ctc = None
      loss_pre = None
      stats = dict()

      # 2a. CTC branch
      if self.ctc_weight != 0.0:
          loss_ctc, cer_ctc = self._calc_ctc_loss(encoder_out, encoder_out_lens, text, text_lengths)
      stats["loss_ctc"] = None if loss_ctc is None else loss_ctc.detach()
      stats["cer_ctc"] = cer_ctc

      # Optional intermediate CTC, averaged over layers and blended into loss_ctc.
      loss_interctc = 0.0
      if self.interctc_weight != 0.0 and intermediate_outs is not None:
          for layer_idx, layer_out in intermediate_outs:
              # intermediate outputs are assumed to share encoder_out's lengths/padding
              layer_loss, layer_cer = self._calc_ctc_loss(layer_out, encoder_out_lens, text, text_lengths)
              loss_interctc = loss_interctc + layer_loss
              stats["loss_interctc_layer{}".format(layer_idx)] = (
                  None if layer_loss is None else layer_loss.detach()
              )
              stats["cer_interctc_layer{}".format(layer_idx)] = layer_cer
          loss_interctc = loss_interctc / len(intermediate_outs)
          loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * loss_interctc

      # 2b. Attention decoder branch
      if self.ctc_weight != 1.0:
          loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
              encoder_out, encoder_out_lens, text, text_lengths
          )

      # 3. Combine branch losses according to ctc_weight.
      if self.ctc_weight == 0.0:
          loss = loss_att + loss_pre * self.predictor_weight
      elif self.ctc_weight == 1.0:
          loss = loss_ctc
      else:
          loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

      stats["loss_att"] = None if loss_att is None else loss_att.detach()
      stats["acc"] = acc_att
      stats["cer"] = cer_att
      stats["wer"] = wer_att
      stats["loss_pre"] = None if loss_pre is None else loss_pre.detach().cpu()
      stats["loss"] = torch.clone(loss.detach())
      # force_gatherable: to-device and to-tensor if scalar for DataParallel
      loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
      return loss, stats, weight
  def _sample_hot_word(self, ys_pad, ys_pad_lens):
      """Sample hotword spans from the targets and encode them with the bias LSTM.

      For each target row with length >= 2:
        * if it is longer than ``min_hw_length + max_hw_length + 2`` and
          ``random.random() < self.double_rate``, two non-overlapping spans are
          taken (the first starts within the first third of the row);
        * otherwise, with probability ``self.sample_rate``, one span is taken
          (the whole 2-token row when the length is exactly 2).
      A single placeholder token ``0`` is always entry 0 of the list.

      The padded spans are embedded with ``self.decoder.embed`` and run through
      ``self.bias_encoder``. For each span, the encoder output at its last token
      (index kept in ``hw_lengths``) is selected.

      When ``self.target_buffer_length > 0``, selections are accumulated across
      calls in ``self.hotword_buffer`` (older entries detached). Once the buffer
      would reach ``target_buffer_length``, it is trimmed to its newest
      ``randint(target_buffer_length // 2, target_buffer_length) + 10`` rows.

      Args:
          ys_pad: padded target token ids, indexed as ``ys_pad[i][start:end]``.
          ys_pad_lens: valid length per row; presumably a 1-D integer tensor,
              so ``length`` below is a 0-d tensor -- TODO confirm.

      Returns:
          The selected encodings repeated ``ys_pad.shape[0]`` times along a new
          leading axis; presumably (batch, num_hotwords, hidden) given the
          batch-first LSTM built in ``__init__``.

      NOTE(review): training embeds hotwords with ``self.decoder.embed``, while
      ``cal_decoder_with_predictor`` (inference) uses ``self.bias_embed``.
      Confirm this asymmetry is intended.
      NOTE(review): ``random.randint(0, length - self.min_hw_length)`` raises
      ValueError if ``min_hw_length`` exceeds a row length >= 3.
      """
      hw_list = [torch.Tensor([0]).long().to(ys_pad.device)]
      hw_lengths = [0]  # this length is actually for indice, so -1
      for i, length in enumerate(ys_pad_lens):
          if length < 2:
              continue
          if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate:
              # sample double hotword
              _max_hw_length = min(self.max_hw_length, length // 2)
              # first hotword
              start1 = random.randint(0, length // 3)
              end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1)
              hw_tokens1 = ys_pad[i][start1:end1 + 1]
              hw_lengths.append(len(hw_tokens1) - 1)
              hw_list.append(hw_tokens1)
              # second hotword, starting strictly after the first one ends
              start2 = random.randint(end1 + 1, length - self.min_hw_length)
              end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1),
                                    min(length - 1, start2 + self.max_hw_length - 1))
              hw_tokens2 = ys_pad[i][start2:end2 + 1]
              hw_lengths.append(len(hw_tokens2) - 1)
              hw_list.append(hw_tokens2)
              continue
          if random.random() < self.sample_rate:
              if length == 2:
                  hw_tokens = ys_pad[i][:2]
                  hw_lengths.append(1)
                  hw_list.append(hw_tokens)
              else:
                  start = random.randint(0, length - self.min_hw_length)
                  # `end` is exclusive here (note the trailing + 1), clipped to the row.
                  end = random.randint(min(length - 1, start + self.min_hw_length - 1),
                                       min(length - 1, start + self.max_hw_length - 1)) + 1
                  hw_tokens = ys_pad[i][start:end]
                  hw_lengths.append(len(hw_tokens) - 1)
                  hw_list.append(hw_tokens)
      # padding
      hw_list_pad = pad_list(hw_list, 0)
      hw_embed = self.decoder.embed(hw_list_pad)
      hw_embed, (_, _) = self.bias_encoder(hw_embed)
      _ind = np.arange(0, len(hw_list)).tolist()
      # Pick, for every span, the encoder output at its last token.
      selected = hw_embed[_ind, hw_lengths]
      # update self.hotword_buffer, throw a part if oversize
      if self.target_buffer_length > 0:
          _b = selected.shape[0]
          if self.hotword_buffer is None:
              self.hotword_buffer = selected
              self.length_record.append(selected.shape[0])
              self.current_buffer_length = _b
          elif self.current_buffer_length + _b < self.target_buffer_length:
              self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
              self.current_buffer_length += _b
              selected = self.hotword_buffer
          else:
              self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
              random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
              self.hotword_buffer = self.hotword_buffer[-1 * random_throw:]
              selected = self.hotword_buffer
              self.current_buffer_length = selected.shape[0]
      return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
  def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
      """Replace a share of the acoustic embeddings with target-token embeddings.

      A first decoder pass (no grad) predicts tokens from ``pre_acoustic_embeds``.
      For each row, ``int((valid_len - num_correct) * self.sampling_ratio)``
      random valid positions take the ground-truth token embedding; all other
      positions keep the acoustic embedding.

      Args:
          encoder_out, encoder_out_lens: encoder output and lengths for the decoder.
          ys_pad: padded target ids; padding is assumed to be ``self.ignore_id``.
          ys_pad_lens: valid target length per row.
          pre_acoustic_embeds: predictor output fed to the first decoder pass.
          contextual_info: hotword encodings passed to the decoder.

      Returns:
          ``(semantic_embeds, first_pass_decoder_out)``, both zeroed on padding.
      """
      tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
      # Zero the padding (ignore_id may be negative) only for the embedding lookup.
      # `ys_pad` itself must keep ignore_id so the padding detection below works;
      # the previous code overwrote it, making every position look valid.
      ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
      if self.share_embedding:
          ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
      else:
          ys_pad_embed = self.decoder.embed(ys_pad_masked)
      with torch.no_grad():
          decoder_outs = self.decoder(
              encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
          )
          decoder_out, _ = decoder_outs[0], decoder_outs[1]
          pred_tokens = decoder_out.argmax(-1)
          nonpad_positions = ys_pad.ne(self.ignore_id)
          seq_lens = nonpad_positions.sum(1)
          same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
          input_mask = torch.ones_like(nonpad_positions)
          bsz = ys_pad.size(0)
          for li in range(bsz):
              target_num = (((seq_lens[li] - same_num[li]).float()) * self.sampling_ratio).long()
              if target_num > 0:
                  # Place the sampled indices on the mask's device (was a hard-coded `.cuda()`).
                  sampled = torch.randperm(seq_lens[li])[:target_num].to(input_mask.device)
                  input_mask[li].scatter_(dim=0, index=sampled, value=0)
          # True -> keep acoustic embedding; False -> use ground-truth embedding.
          input_mask = input_mask.eq(1)
          input_mask = input_mask.masked_fill(~nonpad_positions, False)
          input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
      sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
          input_mask_expand_dim, 0)
      return sematic_embeds * tgt_mask, decoder_out * tgt_mask
  def _calc_att_loss(
          self,
          encoder_out: torch.Tensor,
          encoder_out_lens: torch.Tensor,
          ys_pad: torch.Tensor,
          ys_pad_lens: torch.Tensor,
  ):
      """Attention-branch losses with hotword context sampled from the targets.

      Order: predictor -> hotword sampling (``self._sample_hot_word``) ->
      optional sampler -> decoder, each with the sampled contextual info.

      Returns:
          ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``. ``cer_att`` and
          ``wer_att`` are None while training or without an error calculator.
      """
      frame_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
          encoder_out.device)
      if self.predictor_bias == 1:
          # Targets gain sos/eos, so every length grows by the bias.
          _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
          ys_pad_lens = ys_pad_lens + self.predictor_bias
      acoustic_embeds, predicted_lens, _, _ = self.predictor(encoder_out, ys_pad, frame_mask,
                                                             ignore_id=self.ignore_id)
      # Hotword encodings shared by the sampler pass and the main decoder pass.
      contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)

      # 0. Sampler
      first_pass_out = None
      sampler_on = self.sampling_ratio > 0.0
      verbose = self.step_cur < 2
      if sampler_on:
          if verbose:
              logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
          decoder_inputs, first_pass_out = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                        acoustic_embeds, contextual_info)
      else:
          if verbose:
              logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
          decoder_inputs = acoustic_embeds

      # 1. Forward decoder
      decoder_outs = self.decoder(
          encoder_out, encoder_out_lens, decoder_inputs, ys_pad_lens, contextual_info=contextual_info
      )
      decoder_out, _ = decoder_outs[0], decoder_outs[1]
      # Accuracy and CER/WER use the first-pass output when the sampler produced one.
      if first_pass_out is None:
          first_pass_out = decoder_out

      # 2. Losses and metrics
      loss_att = self.criterion_att(decoder_out, ys_pad)
      acc_att = th_accuracy(
          first_pass_out.view(-1, self.vocab_size),
          ys_pad,
          ignore_label=self.ignore_id,
      )
      loss_pre = self.criterion_pre(ys_pad_lens.type_as(predicted_lens), predicted_lens)
      cer_att = wer_att = None
      if not self.training and self.error_calculator is not None:
          ys_hat = first_pass_out.argmax(dim=-1)
          cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
      return loss_att, acc_att, cer_att, wer_att, loss_pre
  def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
      """Decode with hotword biasing and return log-probabilities.

      Hotwords are embedded with ``self.bias_embed`` and encoded by
      ``self.bias_encoder``. The final LSTM hidden state of each hotword, repeated
      once per utterance in the batch, is passed to the decoder as
      ``contextual_info``.

      Args:
          hw_list: list of hotword token-id sequences. When None, a single
              one-token list holding ``self.sos`` is used instead.

      Returns:
          ``(log_softmax(decoder_out, dim=-1), ys_pad_lens)``.
      """
      if hw_list is None:
          # default hotword list: one entry holding only sos (i.e. "no hotwords")
          placeholder = [torch.Tensor([self.sos]).long().to(encoder_out.device)]
          lstm_input = self.bias_embed(pad_list(placeholder, 0))
      else:
          hotword_lens = [len(hw) for hw in hw_list]
          padded_ids = pad_list([torch.Tensor(hw).long() for hw in hw_list], 0).to(encoder_out.device)
          # Packing lets h_n be taken at each hotword's true last token.
          lstm_input = torch.nn.utils.rnn.pack_padded_sequence(self.bias_embed(padded_ids), hotword_lens,
                                                               batch_first=True, enforce_sorted=False)
      _, (h_n, _) = self.bias_encoder(lstm_input)
      contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
      decoder_out = self.decoder(
          encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
      )[0]
      return torch.log_softmax(decoder_out, dim=-1), ys_pad_lens
  def gen_clas_tf2torch_map_dict(self):
      """Return the TF -> PyTorch parameter map for the hotword (CLAS) modules.

      Keys are PyTorch parameter names. Each value holds the TF variable
      ``name`` plus the post-processing hints read by ``clas_convert_tf2torch``:
      ``squeeze`` and ``transpose`` axes, an optional row ``slice``, a ``scale``
      factor, and ``unit_k``/``unit_b`` (the gate width used to reorder the
      LSTM kernel / bias gates).
      """
      torch_lstm = "bias_encoder"
      tf_lstm = "seq2seq/clas_charrnn"
      torch_emb = "bias_embed"
      tf_emb = "seq2seq"
      kernel_name = "{}/rnn/lstm_cell/kernel".format(tf_lstm)
      bias_name = "{}/rnn/lstm_cell/bias".format(tf_lstm)

      mapping = {}
      # One TF kernel (1024, 2048): input rows [0, 512), recurrent rows [512, 1024).
      # Each slice is transposed into PyTorch's (2048, 512) layout.
      for suffix, rows in (("weight_ih_l0", (0, 512)), ("weight_hh_l0", (512, 1024))):
          mapping["{}.{}".format(torch_lstm, suffix)] = {
              "name": kernel_name,
              "squeeze": None,
              "transpose": (1, 0),
              "slice": rows,
              "unit_k": 512,
          }
      # TF keeps one (2048,) bias. PyTorch adds bias_ih and bias_hh, so each gets half.
      for suffix in ("bias_ih_l0", "bias_hh_l0"):
          mapping["{}.{}".format(torch_lstm, suffix)] = {
              "name": bias_name,
              "squeeze": None,
              "transpose": None,
              "scale": 0.5,
              "unit_b": 512,
          }
      # Character embedding table, copied as-is (e.g. (4235, 256)).
      mapping["{}.weight".format(torch_emb)] = {
          "name": "{}/contextual_encoder/w_char_embs".format(tf_emb),
          "squeeze": None,
          "transpose": None,
      }
      return mapping
  def clas_convert_tf2torch(self,
                            var_dict_tf,
                            var_dict_torch):
      """Convert TF hotword (CLAS) weights into PyTorch tensors for this model.

      Each ``bias_encoder.*`` / ``bias_embed.*`` name in ``var_dict_torch`` is
      handled using its entry in ``self.gen_clas_tf2torch_map_dict()``.
      ``bias_encoder`` names without an entry are skipped. The mapped TF array
      is then processed in this order:
        1. gate reorder;
        2. squeeze;
        3. row slice;
        4. scale;
        5. transpose.
      The result is cast to float32 on CPU and size-checked against the PyTorch tensor.

      Args:
          var_dict_tf: TF variable name -> array. The values are used with
              ``.copy()``/``np.*``, so presumably numpy arrays.
          var_dict_torch: PyTorch parameter name -> tensor; only names and
              ``.size()`` are read.

      Returns:
          dict: PyTorch parameter name -> converted float32 CPU tensor.

      Raises:
          AssertionError: if a converted tensor's size differs from the target.
          KeyError: if a mapped TF variable is missing, or a ``bias_embed`` name
              has no map entry.
      """
      map_dict = self.gen_clas_tf2torch_map_dict()
      var_dict_torch_update = dict()
      for name in sorted(var_dict_torch.keys(), reverse=False):
          names = name.split('.')
          if names[0] == "bias_encoder":
              name_q = name
              if name_q in map_dict.keys():
                  name_v = map_dict[name_q]["name"]
                  name_tf = name_v
                  data_tf = var_dict_tf[name_tf]
                  if map_dict[name_q].get("unit_k") is not None:
                      # Reorder the four gate blocks along the output axis:
                      # (0, 1, 2, 3) -> (0, 2, 1, 3).
                      # NOTE(review): assuming the TF cell stores gates as
                      # (i, j, f, o), this yields PyTorch's (i, f, g, o). The
                      # local names f/o/g do not match the gates they hold.
                      dim = map_dict[name_q]["unit_k"]
                      i = data_tf[:, 0:dim].copy()
                      f = data_tf[:, dim:2 * dim].copy()
                      o = data_tf[:, 2 * dim:3 * dim].copy()
                      g = data_tf[:, 3 * dim:4 * dim].copy()
                      data_tf = np.concatenate([i, o, f, g], axis=1)
                  if map_dict[name_q].get("unit_b") is not None:
                      # Same (0, 2, 1, 3) gate-block reorder for the 1-D bias.
                      # NOTE(review): TF's LSTMCell adds `forget_bias` (default 1.0)
                      # at run time without storing it; it is not added here, so
                      # confirm against the source model if outputs differ.
                      dim = map_dict[name_q]["unit_b"]
                      i = data_tf[0:dim].copy()
                      f = data_tf[dim:2 * dim].copy()
                      o = data_tf[2 * dim:3 * dim].copy()
                      g = data_tf[3 * dim:4 * dim].copy()
                      data_tf = np.concatenate([i, o, f, g], axis=0)
                  if map_dict[name_q]["squeeze"] is not None:
                      data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                  if map_dict[name_q].get("slice") is not None:
                      # Keep rows [start, end) of the kernel: input rows for
                      # weight_ih, recurrent rows for weight_hh.
                      data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]]
                  if map_dict[name_q].get("scale") is not None:
                      # PyTorch sums bias_ih and bias_hh, so each receives a share of the single TF bias.
                      data_tf = data_tf * map_dict[name_q]["scale"]
                  if map_dict[name_q]["transpose"] is not None:
                      data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                  data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                  assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                  var_dict_torch[
                                                                                                      name].size(),
                                                                                                  data_tf.size())
                  var_dict_torch_update[name] = data_tf
                  logging.info(
                      "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                    var_dict_tf[name_tf].shape))
          elif names[0] == "bias_embed":
              # Embedding table: no gate reorder, slice or scale; only the optional squeeze/transpose.
              name_tf = map_dict[name]["name"]
              data_tf = var_dict_tf[name_tf]
              if map_dict[name]["squeeze"] is not None:
                  data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
              if map_dict[name]["transpose"] is not None:
                  data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
              data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
              assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                              var_dict_torch[
                                                                                                  name].size(),
                                                                                              data_tf.size())
              var_dict_torch_update[name] = data_tf
              logging.info(
                  "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                var_dict_tf[name_tf].shape))
      return var_dict_torch_update