# e2e_asr_paraformer.py
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. import random
  11. import numpy as np
  12. from funasr.layers.abs_normalize import AbsNormalize
  13. from funasr.losses.label_smoothing_loss import (
  14. LabelSmoothingLoss, # noqa: H301
  15. )
  16. from funasr.models.ctc import CTC
  17. from funasr.models.decoder.abs_decoder import AbsDecoder
  18. from funasr.models.e2e_asr_common import ErrorCalculator
  19. from funasr.models.encoder.abs_encoder import AbsEncoder
  20. from funasr.models.frontend.abs_frontend import AbsFrontend
  21. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  22. from funasr.models.predictor.cif import mae_loss
  23. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  24. from funasr.models.specaug.abs_specaug import AbsSpecAug
  25. from funasr.modules.add_sos_eos import add_sos_eos
  26. from funasr.modules.nets_utils import make_pad_mask, pad_list
  27. from funasr.modules.nets_utils import th_accuracy
  28. from funasr.torch_utils.device_funcs import force_gatherable
  29. from funasr.models.base_model import FunASRModel
  30. from funasr.models.predictor.cif import CifPredictorV3
# Mixed-precision compatibility shim: torch.cuda.amp.autocast only exists for
# torch >= 1.6.0. On older torch, substitute a no-op context manager with the
# same call signature so `with autocast(False):` works everywhere.
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        # No-op stand-in: yields control without changing autocast state.
        yield
class Paraformer(FunASRModel):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive
    End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    Hybrid CTC / attention model with a CIF-style predictor that emits
    per-token acoustic embeddings, enabling single-pass (non-autoregressive)
    decoding.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
        use_1st_decoder_loss: bool = False,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        # Fall back to the last vocabulary entry when sos/eos are not given.
        self.sos = vocab_size - 1 if sos is None else sos
        self.eos = vocab_size - 1 if eos is None else eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        # Copy so later mutation by the caller cannot affect this model.
        self.token_list = token_list.copy()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        # Encoders without intermediate-CTC support get the flag defaulted off.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            # Projects CTC posteriors back into the encoder hidden size for
            # self-conditioned CTC.
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.error_calculator = None
        if ctc_weight == 1.0:
            # Pure-CTC training: the attention decoder is unused.
            self.decoder = None
        else:
            self.decoder = decoder
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
            if report_cer or report_wer:
                self.error_calculator = ErrorCalculator(
                    token_list, sym_space, sym_blank, report_cer, report_wer
                )
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        # MAE loss between predicted and target token counts (CIF predictor).
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        # Training-step counter; only used to rate-limit log messages.
        self.step_cur = 0
        self.share_embedding = share_embedding
        if self.share_embedding:
            # Reuse the output projection weights as the input embedding.
            self.decoder.embed = None
        self.use_1st_decoder_loss = use_1st_decoder_loss

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: Optional[int] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            decoding_ind: int
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel: trim padding beyond the longest item in the batch
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        # 1. Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            # Chunked (streaming-style) encoder: pick a chunk configuration,
            # randomly during training, or `decoding_ind` at inference.
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            # Encoder returned (final_out, [(layer_idx, intermediate_out), ...])
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermedaite CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        # NOTE(review): these branches assume loss_pre is not None whenever the
        # attention branch ran, i.e. a predictor is configured — confirm callers
        # never set ctc_weight < 1.0 with predictor=None.
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        if self.use_1st_decoder_loss and pre_loss_att is not None:
            # Extra CE loss on the first-pass (glancing) decoder output.
            loss = loss + (1 - self.ctc_weight) * pre_loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Extract features for statistics collection (e.g. global CMVN)."""
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: int
        """
        # Feature extraction runs in fp32 even under AMP.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out, encoder_out_lens, _ = self.encoder(
                    feats, feats_lengths, ctc=self.ctc, ind=ind
                )
                # Strip the chunk overlap so downstream sees contiguous frames.
                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                    encoder_out,
                    encoder_out_lens,
                    chunk_outs=None,
                )
            else:
                encoder_out, encoder_out_lens, _ = self.encoder(
                    feats, feats_lengths, ctc=self.ctc
                )
        else:
            if hasattr(self.encoder, "overlap_chunk_cls"):
                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                    encoder_out,
                    encoder_out_lens,
                    chunk_outs=None,
                )
            else:
                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens

    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the CIF predictor over encoder output (no target labels).

        Returns (pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)
        as produced by self.predictor.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
        """Decode predictor embeddings and return log-softmax posteriors."""
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the frontend (if any) to raw speech; otherwise pass through."""
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            # e.g. STFT and Feature extract
            # data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        # Sum over the token axis -> one NLL per utterance.
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder

        To avoid OOM, this fuction seperate the input into batches.
        Then call nll for each batch and combine and return results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase
                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
                batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
                batch_ys_pad = ys_pad[start_idx:end_idx, :]
                batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
                batch_nll = self.nll(
                    batch_encoder_out,
                    batch_encoder_out_lens,
                    batch_ys_pad,
                    batch_ys_pad_lens,
                )
                nll.append(batch_nll)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Predictor + (optionally sampled) decoder forward and loss computation.

        Returns (loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att).
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            # Append <eos> to targets so the predictor also counts the
            # sentence-end token.
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)

        # 0. sampler
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)
            else:
                sematic_embeds, decoder_out_1st = self.sampler(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        # Predictor length loss: MAE between predicted and true token counts.
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Glancing-style sampler (no gradient through the first decoder pass).

        Runs a first decoder pass on the predictor embeddings, then replaces a
        fraction (proportional to sampling_ratio and the first-pass error count)
        of the acoustic embeddings with ground-truth token embeddings.
        Returns (sematic_embeds, decoder_out_1st), both masked by target padding.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        # Zero out ignore_id padding so embedding lookup gets valid indices.
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            # First-pass decode on pure acoustic embeddings.
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                # Number of positions to replace grows with the error count.
                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
                if target_num > 0:
                    # NOTE(review): hard-coded .cuda() assumes GPU training;
                    # this would fail on CPU-only runs — confirm intended.
                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        # Mix: keep acoustic embedding where mask is True, ground-truth
        # embedding where False.
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Same sampling as `sampler`, but the first decoder pass keeps
        gradients and its CE loss is returned as `pre_loss_att`.

        Returns (sematic_embeds, decoder_out_1st, pre_loss_att).
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        # First pass WITH grad so its loss can train the decoder directly.
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        input_mask = torch.ones_like(nonpad_positions)
        bsz, seq_len = ys_pad.size()
        for li in range(bsz):
            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
            if target_num > 0:
                # NOTE(review): hard-coded .cuda() assumes GPU training;
                # this would fail on CPU-only runs — confirm intended.
                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
        input_mask = input_mask.eq(1)
        input_mask = input_mask.masked_fill(~nonpad_positions, False)
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """CTC loss plus (eval-only) CER computed from CTC greedy output."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
  539. class ParaformerOnline(Paraformer):
  540. """
  541. Author: Speech Lab, Alibaba Group, China
  542. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  543. https://arxiv.org/abs/2206.08317
  544. """
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        decoder_attention_chunk_type: str = 'chunk',
        share_embedding: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
        use_1st_decoder_loss: bool = False,
    ):
        """Construct the streaming Paraformer.

        Beyond the base-class arguments, `decoder_attention_chunk_type` selects
        the SCAMA cross-attention masking mode, `share_embedding` ties the
        decoder input embedding to the output layer weight, and
        `use_1st_decoder_loss` adds the first-pass decoder loss to the total.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )
        # NOTE(review): the attributes below repeat what super().__init__ already
        # set; they are kept verbatim to preserve the original initialization order.
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = vocab_size - 1 if sos is None else sos
        self.eos = vocab_size - 1 if eos is None else eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.error_calculator = None
        # A pure-CTC model needs no attention decoder.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        # An attention-only model needs no CTC head.
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0  # global step counter used to rate-limit log messages
        self.scama_mask = None  # cached decoder cross-attention mask (set by calc_predictor)
        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
            # NOTE(review): only set when the encoder has an overlap_chunk_cls —
            # other code paths that read this attribute assume that case; verify.
            self.decoder_attention_chunk_type = decoder_attention_chunk_type
        self.share_embedding = share_embedding
        if self.share_embedding:
            # Decoder reuses output_layer.weight as its input embedding.
            self.decoder.embed = None
        self.use_1st_decoder_loss = use_1st_decoder_loss
  665. def forward(
  666. self,
  667. speech: torch.Tensor,
  668. speech_lengths: torch.Tensor,
  669. text: torch.Tensor,
  670. text_lengths: torch.Tensor,
  671. decoding_ind: int = None,
  672. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  673. """Frontend + Encoder + Decoder + Calc loss
  674. Args:
  675. speech: (Batch, Length, ...)
  676. speech_lengths: (Batch, )
  677. text: (Batch, Length)
  678. text_lengths: (Batch,)
  679. decoding_ind: int
  680. """
  681. assert text_lengths.dim() == 1, text_lengths.shape
  682. # Check that batch_size is unified
  683. assert (
  684. speech.shape[0]
  685. == speech_lengths.shape[0]
  686. == text.shape[0]
  687. == text_lengths.shape[0]
  688. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  689. batch_size = speech.shape[0]
  690. self.step_cur += 1
  691. # for data-parallel
  692. text = text[:, : text_lengths.max()]
  693. speech = speech[:, :speech_lengths.max()]
  694. # 1. Encoder
  695. if hasattr(self.encoder, "overlap_chunk_cls"):
  696. ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
  697. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
  698. else:
  699. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  700. intermediate_outs = None
  701. if isinstance(encoder_out, tuple):
  702. intermediate_outs = encoder_out[1]
  703. encoder_out = encoder_out[0]
  704. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  705. loss_ctc, cer_ctc = None, None
  706. loss_pre = None
  707. stats = dict()
  708. # 1. CTC branch
  709. if self.ctc_weight != 0.0:
  710. if hasattr(self.encoder, "overlap_chunk_cls"):
  711. encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
  712. encoder_out_lens,
  713. chunk_outs=None)
  714. loss_ctc, cer_ctc = self._calc_ctc_loss(
  715. encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
  716. )
  717. # Collect CTC branch stats
  718. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  719. stats["cer_ctc"] = cer_ctc
  720. # Intermediate CTC (optional)
  721. loss_interctc = 0.0
  722. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  723. for layer_idx, intermediate_out in intermediate_outs:
  724. # we assume intermediate_out has the same length & padding
  725. # as those of encoder_out
  726. if hasattr(self.encoder, "overlap_chunk_cls"):
  727. encoder_out_ctc, encoder_out_lens_ctc = \
  728. self.encoder.overlap_chunk_cls.remove_chunk(
  729. intermediate_out,
  730. encoder_out_lens,
  731. chunk_outs=None)
  732. loss_ic, cer_ic = self._calc_ctc_loss(
  733. encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
  734. )
  735. loss_interctc = loss_interctc + loss_ic
  736. # Collect Intermedaite CTC stats
  737. stats["loss_interctc_layer{}".format(layer_idx)] = (
  738. loss_ic.detach() if loss_ic is not None else None
  739. )
  740. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  741. loss_interctc = loss_interctc / len(intermediate_outs)
  742. # calculate whole encoder loss
  743. loss_ctc = (
  744. 1 - self.interctc_weight
  745. ) * loss_ctc + self.interctc_weight * loss_interctc
  746. # 2b. Attention decoder branch
  747. if self.ctc_weight != 1.0:
  748. loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
  749. encoder_out, encoder_out_lens, text, text_lengths
  750. )
  751. # 3. CTC-Att loss definition
  752. if self.ctc_weight == 0.0:
  753. loss = loss_att + loss_pre * self.predictor_weight
  754. elif self.ctc_weight == 1.0:
  755. loss = loss_ctc
  756. else:
  757. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
  758. if self.use_1st_decoder_loss and pre_loss_att is not None:
  759. loss = loss + pre_loss_att
  760. # Collect Attn branch stats
  761. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  762. stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
  763. stats["acc"] = acc_att
  764. stats["cer"] = cer_att
  765. stats["wer"] = wer_att
  766. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  767. stats["loss"] = torch.clone(loss.detach())
  768. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  769. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  770. return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: chunk-configuration index forwarded to the encoder

        Returns:
            (encoder_out, encoder_out_lens); when intermediate CTC outputs
            exist, encoder_out is the tuple (encoder_out, intermediate_outs).
        """
        # Feature extraction / augmentation / CMVN run in fp32 even under AMP.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation (training only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def encode_chunk(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder for one streaming chunk; state lives in `cache`.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            cache: dict holding streaming state; this method reads cache["encoder"]

        Returns:
            (encoder_out, lengths). In the common path the lengths tensor is
            synthesized from the chunk's frame count rather than tracked.
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation (training only; normally inactive when streaming)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder chunk-wise, carrying state in cache["encoder"]
        # feats: (Batch, Length, Dim) -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
                feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        # NOTE(review): the main path returns a fresh length tensor equal to the
        # chunk's output frame count, ignoring encoder_out_lens — confirm intended.
        return encoder_out, torch.tensor([encoder_out.size(1)])
    def _calc_att_predictor_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Predictor (CIF) + decoder forward and loss computation.

        Builds the SCAMA cross-attention mask when the encoder is chunked,
        optionally mixes ground-truth embeddings via the sampler, then runs
        the decoder and computes attention/predictor losses.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att);
            pre_loss_att is None unless use_1st_decoder_loss is active.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            # Append eos so the predictor learns the extra terminal token.
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            # Zero out the shifted-chunk padding frames before the predictor.
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(
            encoder_out,
            ys_pad,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=ys_pad_lens,
        )
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas, encoder_out_lens)

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            # Chunked decoder attention: build the SCAMA mask from predictor alignments.
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=ys_pad_lens,
                is_training=self.training,
            )
        elif self.encoder.overlap_chunk_cls is not None:
            # Non-chunk attention: drop the duplicated chunk frames instead.
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out,
                encoder_out_lens,
                chunk_outs=None)

        # 0. sampler
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                # Sampler variant that keeps gradients of the first decoder pass.
                sematic_embeds, decoder_out_1st, pre_loss_att = \
                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
            else:
                sematic_embeds, decoder_out_1st = \
                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        # Accuracy is measured on the first-pass output.
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        # MAE between predicted token count and target length.
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
  961. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
  962. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  963. ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
  964. if self.share_embedding:
  965. ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
  966. else:
  967. ys_pad_embed = self.decoder.embed(ys_pad_masked)
  968. with torch.no_grad():
  969. decoder_outs = self.decoder(
  970. encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
  971. )
  972. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  973. pred_tokens = decoder_out.argmax(-1)
  974. nonpad_positions = ys_pad.ne(self.ignore_id)
  975. seq_lens = (nonpad_positions).sum(1)
  976. same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
  977. input_mask = torch.ones_like(nonpad_positions)
  978. bsz, seq_len = ys_pad.size()
  979. for li in range(bsz):
  980. target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
  981. if target_num > 0:
  982. input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
  983. input_mask = input_mask.eq(1)
  984. input_mask = input_mask.masked_fill(~nonpad_positions, False)
  985. input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
  986. sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
  987. input_mask_expand_dim, 0)
  988. return sematic_embeds * tgt_mask, decoder_out * tgt_mask
  989. def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
  990. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  991. ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
  992. if self.share_embedding:
  993. ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
  994. else:
  995. ys_pad_embed = self.decoder.embed(ys_pad_masked)
  996. decoder_outs = self.decoder(
  997. encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
  998. )
  999. pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
  1000. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  1001. pred_tokens = decoder_out.argmax(-1)
  1002. nonpad_positions = ys_pad.ne(self.ignore_id)
  1003. seq_lens = (nonpad_positions).sum(1)
  1004. same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
  1005. input_mask = torch.ones_like(nonpad_positions)
  1006. bsz, seq_len = ys_pad.size()
  1007. for li in range(bsz):
  1008. target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
  1009. if target_num > 0:
  1010. input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
  1011. input_mask = input_mask.eq(1)
  1012. input_mask = input_mask.masked_fill(~nonpad_positions, False)
  1013. input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
  1014. sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
  1015. input_mask_expand_dim, 0)
  1016. return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Inference-time predictor pass; caches the SCAMA mask on `self`.

        Runs the CIF predictor without targets, derives frame alignments, and
        (for chunked encoders in 'chunk' mode) builds and stores the decoder
        cross-attention mask in self.scama_mask for the subsequent
        cal_decoder_with_predictor call.

        Returns:
            (pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index)
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        mask_chunk_predictor = None
        if self.encoder.overlap_chunk_cls is not None:
            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(
                None,
                device=encoder_out.device,
                batch_size=encoder_out.size(0))
            # Zero out the shifted-chunk padding frames before the predictor.
            encoder_out = encoder_out * mask_shfit_chunk
        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
            encoder_out,
            None,
            encoder_out_mask,
            ignore_id=self.ignore_id,
            mask_chunk_predictor=mask_chunk_predictor,
            target_label_length=None,
        )
        # Tail-threshold predictors emit one extra frame; extend lengths accordingly.
        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(
            pre_alphas,
            encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)

        scama_mask = None
        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
            attention_chunk_center_bias = 0
            attention_chunk_size = encoder_chunk_size
            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
                get_mask_shift_att_chunk_decoder(None,
                                                 device=encoder_out.device,
                                                 batch_size=encoder_out.size(0)
                                                 )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=encoder_chunk_size,
                attention_chunk_center_bias=attention_chunk_center_bias,
                attention_chunk_size=attention_chunk_size,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=decoder_att_look_back_factor,
                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
                target_length=None,
                is_training=self.training,
            )
        # Cache for cal_decoder_with_predictor (stateful across the two calls).
        self.scama_mask = scama_mask

        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
  1066. def calc_predictor_chunk(self, encoder_out, cache=None):
  1067. pre_acoustic_embeds, pre_token_length = \
  1068. self.predictor.forward_chunk(encoder_out, cache["encoder"])
  1069. return pre_acoustic_embeds, pre_token_length
  1070. def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
  1071. decoder_outs = self.decoder(
  1072. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
  1073. )
  1074. decoder_out = decoder_outs[0]
  1075. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  1076. return decoder_out, ys_pad_lens
  1077. def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
  1078. decoder_outs = self.decoder.forward_chunk(
  1079. encoder_out, sematic_embeds, cache["decoder"]
  1080. )
  1081. decoder_out = decoder_outs
  1082. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  1083. return decoder_out
class ParaformerBert(Paraformer):
    """Paraformer variant distilling BERT embeddings into the decoder.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive
    end-to-end speech recognition

    Adds a projection from an intermediate decoder layer to the external
    embedding space and a cosine-similarity loss against those embeddings.
    """
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        embeds_id: int = 2,
        embeds_loss_weight: float = 0.0,
        embed_dims: int = 768,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
    ):
        """Construct ParaformerBert.

        Beyond the base-class arguments: `embeds_id` selects which decoder
        layer's hidden states are exposed for distillation, `embed_dims` is the
        external (e.g. BERT) embedding size, and `embeds_loss_weight` scales
        the cosine distillation loss.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )
        # Tell the decoder which layer's hidden states to return for distillation.
        self.decoder.embeds_id = embeds_id
        decoder_attention_dim = self.decoder.attention_dim
        # Project decoder hidden states into the external embedding space.
        self.pro_nn = torch.nn.Linear(decoder_attention_dim, embed_dims)
        self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.embeds_loss_weight = embeds_loss_weight
        self.length_normalized_loss = length_normalized_loss
    def _calc_embed_loss(self,
                         ys_pad: torch.Tensor,
                         ys_pad_lens: torch.Tensor,
                         embed: torch.Tensor = None,
                         embed_lengths: torch.Tensor = None,
                         embeds_outputs: torch.Tensor = None,
                         ):
        """Cosine-distance loss between projected decoder states and external embeddings.

        Args:
            ys_pad: padded targets, used for device placement and masking
            ys_pad_lens: valid target lengths (Batch,)
            embed: external (e.g. BERT) embeddings (B, L, D); NOTE(review):
                multiplied in place below, so the caller's tensor is mutated
            embed_lengths: lengths of `embed` (unused here)
            embeds_outputs: decoder hidden states to be projected (B, L, d)

        Returns:
            Scalar cosine loss, normalized by token count or batch size.
        """
        embeds_outputs = self.pro_nn(embeds_outputs)
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        embeds_outputs *= tgt_mask  # b x l x d
        embed *= tgt_mask  # b x l x d
        # 1 - cos similarity per token, then zero out padding positions.
        cos_loss = 1.0 - self.cos(embeds_outputs, embed)
        cos_loss *= tgt_mask.squeeze(2)
        if self.length_normalized_loss:
            token_num_total = torch.sum(tgt_mask)
        else:
            # Normalize by batch size when not length-normalizing.
            token_num_total = tgt_mask.size()[0]
        cos_loss_total = torch.sum(cos_loss)
        cos_loss = cos_loss_total / token_num_total
        # print("cos_loss: {}".format(cos_loss))
        return cos_loss
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Predictor + decoder forward and attention/predictor losses.

        Returns:
            (loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs);
            embeds_outputs is the decoder's intermediate hidden states (or None)
            used later for the embedding-distillation loss.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            # Append eos so the predictor learns the extra terminal token.
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info(
                    "enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)
        else:
            if self.step_cur < 2:
                logging.info(
                    "disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        embeds_outputs = None
        # Decoder optionally returns intermediate hidden states as a 3rd output.
        if len(decoder_outs) > 2:
            embeds_outputs = decoder_outs[2]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        # Accuracy is measured on the first-pass output.
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        # MAE between predicted token count and target length.
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        embed: torch.Tensor = None,
        embed_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            embed: external (e.g. BERT) embeddings for distillation (optional)
            embed_lengths: lengths of `embed` (Batch,)

        Returns:
            (loss, stats, weight) made gatherable for DataParallel.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel: trim to the longest valid length in this shard
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        if embed is not None:
            embed = embed[:, :embed_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        # Initialized as floats so the `> 0.0` stats guards below work even
        # when the corresponding branch never runs.
        loss_pre = 0.0
        cos_loss = 0.0
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_ret = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss_att, acc_att, cer_att, wer_att, loss_pre = loss_ret[0], loss_ret[1], loss_ret[2], loss_ret[3], \
                loss_ret[4]
            embeds_outputs = None
            if len(loss_ret) > 5:
                embeds_outputs = loss_ret[5]
            if embeds_outputs is not None:
                # Distill the decoder's intermediate states toward `embed`.
                cos_loss = self._calc_embed_loss(text, text_lengths, embed, embed_lengths, embeds_outputs)

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre > 0.0 else None
        stats["cos_loss"] = cos_loss.detach().cpu() if cos_loss > 0.0 else None
        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
  1333. class BiCifParaformer(Paraformer):
  1334. """
  1335. Paraformer model with an extra cif predictor
  1336. to conduct accurate timestamp prediction
  1337. """
  1338. def __init__(
  1339. self,
  1340. vocab_size: int,
  1341. token_list: Union[Tuple[str, ...], List[str]],
  1342. frontend: Optional[AbsFrontend],
  1343. specaug: Optional[AbsSpecAug],
  1344. normalize: Optional[AbsNormalize],
  1345. encoder: AbsEncoder,
  1346. decoder: AbsDecoder,
  1347. ctc: CTC,
  1348. ctc_weight: float = 0.5,
  1349. interctc_weight: float = 0.0,
  1350. ignore_id: int = -1,
  1351. blank_id: int = 0,
  1352. sos: int = 1,
  1353. eos: int = 2,
  1354. lsm_weight: float = 0.0,
  1355. length_normalized_loss: bool = False,
  1356. report_cer: bool = True,
  1357. report_wer: bool = True,
  1358. sym_space: str = "<space>",
  1359. sym_blank: str = "<blank>",
  1360. extract_feats_in_collect_stats: bool = True,
  1361. predictor=None,
  1362. predictor_weight: float = 0.0,
  1363. predictor_bias: int = 0,
  1364. sampling_ratio: float = 0.2,
  1365. preencoder: Optional[AbsPreEncoder] = None,
  1366. postencoder: Optional[AbsPostEncoder] = None,
  1367. ):
  1368. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  1369. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  1370. super().__init__(
  1371. vocab_size=vocab_size,
  1372. token_list=token_list,
  1373. frontend=frontend,
  1374. specaug=specaug,
  1375. normalize=normalize,
  1376. preencoder=preencoder,
  1377. encoder=encoder,
  1378. postencoder=postencoder,
  1379. decoder=decoder,
  1380. ctc=ctc,
  1381. ctc_weight=ctc_weight,
  1382. interctc_weight=interctc_weight,
  1383. ignore_id=ignore_id,
  1384. blank_id=blank_id,
  1385. sos=sos,
  1386. eos=eos,
  1387. lsm_weight=lsm_weight,
  1388. length_normalized_loss=length_normalized_loss,
  1389. report_cer=report_cer,
  1390. report_wer=report_wer,
  1391. sym_space=sym_space,
  1392. sym_blank=sym_blank,
  1393. extract_feats_in_collect_stats=extract_feats_in_collect_stats,
  1394. predictor=predictor,
  1395. predictor_weight=predictor_weight,
  1396. predictor_bias=predictor_bias,
  1397. sampling_ratio=sampling_ratio,
  1398. )
  1399. assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
  1400. def _calc_pre2_loss(
  1401. self,
  1402. encoder_out: torch.Tensor,
  1403. encoder_out_lens: torch.Tensor,
  1404. ys_pad: torch.Tensor,
  1405. ys_pad_lens: torch.Tensor,
  1406. ):
  1407. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  1408. encoder_out.device)
  1409. if self.predictor_bias == 1:
  1410. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  1411. ys_pad_lens = ys_pad_lens + self.predictor_bias
  1412. _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
  1413. # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  1414. loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
  1415. return loss_pre2
  1416. def _calc_att_loss(
  1417. self,
  1418. encoder_out: torch.Tensor,
  1419. encoder_out_lens: torch.Tensor,
  1420. ys_pad: torch.Tensor,
  1421. ys_pad_lens: torch.Tensor,
  1422. ):
  1423. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  1424. encoder_out.device)
  1425. if self.predictor_bias == 1:
  1426. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  1427. ys_pad_lens = ys_pad_lens + self.predictor_bias
  1428. pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
  1429. ignore_id=self.ignore_id)
  1430. # 0. sampler
  1431. decoder_out_1st = None
  1432. if self.sampling_ratio > 0.0:
  1433. if self.step_cur < 2:
  1434. logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  1435. sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
  1436. pre_acoustic_embeds)
  1437. else:
  1438. if self.step_cur < 2:
  1439. logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  1440. sematic_embeds = pre_acoustic_embeds
  1441. # 1. Forward decoder
  1442. decoder_outs = self.decoder(
  1443. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
  1444. )
  1445. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  1446. if decoder_out_1st is None:
  1447. decoder_out_1st = decoder_out
  1448. # 2. Compute attention loss
  1449. loss_att = self.criterion_att(decoder_out, ys_pad)
  1450. acc_att = th_accuracy(
  1451. decoder_out_1st.view(-1, self.vocab_size),
  1452. ys_pad,
  1453. ignore_label=self.ignore_id,
  1454. )
  1455. loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  1456. # Compute cer/wer using attention-decoder
  1457. if self.training or self.error_calculator is None:
  1458. cer_att, wer_att = None, None
  1459. else:
  1460. ys_hat = decoder_out_1st.argmax(dim=-1)
  1461. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  1462. return loss_att, acc_att, cer_att, wer_att, loss_pre
  1463. def calc_predictor(self, encoder_out, encoder_out_lens):
  1464. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  1465. encoder_out.device)
  1466. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
  1467. None,
  1468. encoder_out_mask,
  1469. ignore_id=self.ignore_id)
  1470. return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
  1471. def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
  1472. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  1473. encoder_out.device)
  1474. ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
  1475. encoder_out_mask,
  1476. token_num)
  1477. return ds_alphas, ds_cif_peak, us_alphas, us_peaks
  1478. def forward(
  1479. self,
  1480. speech: torch.Tensor,
  1481. speech_lengths: torch.Tensor,
  1482. text: torch.Tensor,
  1483. text_lengths: torch.Tensor,
  1484. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  1485. """Frontend + Encoder + Decoder + Calc loss
  1486. Args:
  1487. speech: (Batch, Length, ...)
  1488. speech_lengths: (Batch, )
  1489. text: (Batch, Length)
  1490. text_lengths: (Batch,)
  1491. """
  1492. assert text_lengths.dim() == 1, text_lengths.shape
  1493. # Check that batch_size is unified
  1494. assert (
  1495. speech.shape[0]
  1496. == speech_lengths.shape[0]
  1497. == text.shape[0]
  1498. == text_lengths.shape[0]
  1499. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  1500. batch_size = speech.shape[0]
  1501. self.step_cur += 1
  1502. # for data-parallel
  1503. text = text[:, : text_lengths.max()]
  1504. speech = speech[:, :speech_lengths.max()]
  1505. # 1. Encoder
  1506. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  1507. intermediate_outs = None
  1508. if isinstance(encoder_out, tuple):
  1509. intermediate_outs = encoder_out[1]
  1510. encoder_out = encoder_out[0]
  1511. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  1512. loss_ctc, cer_ctc = None, None
  1513. loss_pre = None
  1514. stats = dict()
  1515. # 1. CTC branch
  1516. if self.ctc_weight != 0.0:
  1517. loss_ctc, cer_ctc = self._calc_ctc_loss(
  1518. encoder_out, encoder_out_lens, text, text_lengths
  1519. )
  1520. # Collect CTC branch stats
  1521. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  1522. stats["cer_ctc"] = cer_ctc
  1523. # Intermediate CTC (optional)
  1524. loss_interctc = 0.0
  1525. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  1526. for layer_idx, intermediate_out in intermediate_outs:
  1527. # we assume intermediate_out has the same length & padding
  1528. # as those of encoder_out
  1529. loss_ic, cer_ic = self._calc_ctc_loss(
  1530. intermediate_out, encoder_out_lens, text, text_lengths
  1531. )
  1532. loss_interctc = loss_interctc + loss_ic
  1533. # Collect Intermedaite CTC stats
  1534. stats["loss_interctc_layer{}".format(layer_idx)] = (
  1535. loss_ic.detach() if loss_ic is not None else None
  1536. )
  1537. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  1538. loss_interctc = loss_interctc / len(intermediate_outs)
  1539. # calculate whole encoder loss
  1540. loss_ctc = (
  1541. 1 - self.interctc_weight
  1542. ) * loss_ctc + self.interctc_weight * loss_interctc
  1543. # 2b. Attention decoder branch
  1544. if self.ctc_weight != 1.0:
  1545. loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
  1546. encoder_out, encoder_out_lens, text, text_lengths
  1547. )
  1548. loss_pre2 = self._calc_pre2_loss(
  1549. encoder_out, encoder_out_lens, text, text_lengths
  1550. )
  1551. # 3. CTC-Att loss definition
  1552. if self.ctc_weight == 0.0:
  1553. loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
  1554. elif self.ctc_weight == 1.0:
  1555. loss = loss_ctc
  1556. else:
  1557. loss = self.ctc_weight * loss_ctc + (
  1558. 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
  1559. # Collect Attn branch stats
  1560. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  1561. stats["acc"] = acc_att
  1562. stats["cer"] = cer_att
  1563. stats["wer"] = wer_att
  1564. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  1565. stats["loss_pre2"] = loss_pre2.detach().cpu()
  1566. stats["loss"] = torch.clone(loss.detach())
  1567. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  1568. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  1569. return loss, stats, weight
  1570. class ContextualParaformer(Paraformer):
  1571. """
  1572. Paraformer model with contextual hotword
  1573. """
  1574. def __init__(
  1575. self,
  1576. vocab_size: int,
  1577. token_list: Union[Tuple[str, ...], List[str]],
  1578. frontend: Optional[AbsFrontend],
  1579. specaug: Optional[AbsSpecAug],
  1580. normalize: Optional[AbsNormalize],
  1581. encoder: AbsEncoder,
  1582. decoder: AbsDecoder,
  1583. ctc: CTC,
  1584. ctc_weight: float = 0.5,
  1585. interctc_weight: float = 0.0,
  1586. ignore_id: int = -1,
  1587. blank_id: int = 0,
  1588. sos: int = 1,
  1589. eos: int = 2,
  1590. lsm_weight: float = 0.0,
  1591. length_normalized_loss: bool = False,
  1592. report_cer: bool = True,
  1593. report_wer: bool = True,
  1594. sym_space: str = "<space>",
  1595. sym_blank: str = "<blank>",
  1596. extract_feats_in_collect_stats: bool = True,
  1597. predictor=None,
  1598. predictor_weight: float = 0.0,
  1599. predictor_bias: int = 0,
  1600. sampling_ratio: float = 0.2,
  1601. min_hw_length: int = 2,
  1602. max_hw_length: int = 4,
  1603. sample_rate: float = 0.6,
  1604. batch_rate: float = 0.5,
  1605. double_rate: float = -1.0,
  1606. target_buffer_length: int = -1,
  1607. inner_dim: int = 256,
  1608. bias_encoder_type: str = 'lstm',
  1609. label_bracket: bool = False,
  1610. use_decoder_embedding: bool = False,
  1611. preencoder: Optional[AbsPreEncoder] = None,
  1612. postencoder: Optional[AbsPostEncoder] = None,
  1613. ):
  1614. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  1615. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  1616. super().__init__(
  1617. vocab_size=vocab_size,
  1618. token_list=token_list,
  1619. frontend=frontend,
  1620. specaug=specaug,
  1621. normalize=normalize,
  1622. preencoder=preencoder,
  1623. encoder=encoder,
  1624. postencoder=postencoder,
  1625. decoder=decoder,
  1626. ctc=ctc,
  1627. ctc_weight=ctc_weight,
  1628. interctc_weight=interctc_weight,
  1629. ignore_id=ignore_id,
  1630. blank_id=blank_id,
  1631. sos=sos,
  1632. eos=eos,
  1633. lsm_weight=lsm_weight,
  1634. length_normalized_loss=length_normalized_loss,
  1635. report_cer=report_cer,
  1636. report_wer=report_wer,
  1637. sym_space=sym_space,
  1638. sym_blank=sym_blank,
  1639. extract_feats_in_collect_stats=extract_feats_in_collect_stats,
  1640. predictor=predictor,
  1641. predictor_weight=predictor_weight,
  1642. predictor_bias=predictor_bias,
  1643. sampling_ratio=sampling_ratio,
  1644. )
  1645. if bias_encoder_type == 'lstm':
  1646. logging.warning("enable bias encoder sampling and contextual training")
  1647. self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)
  1648. self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
  1649. else:
  1650. logging.error("Unsupport bias encoder type")
  1651. self.min_hw_length = min_hw_length
  1652. self.max_hw_length = max_hw_length
  1653. self.sample_rate = sample_rate
  1654. self.batch_rate = batch_rate
  1655. self.target_buffer_length = target_buffer_length
  1656. self.double_rate = double_rate
  1657. if self.target_buffer_length > 0:
  1658. self.hotword_buffer = None
  1659. self.length_record = []
  1660. self.current_buffer_length = 0
  1661. self.use_decoder_embedding = use_decoder_embedding
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        Returns:
            loss: weighted combination of CTC / attention / predictor losses
            stats: detached values for logging
            weight: the batch size (via force_gatherable, for DataParallel)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel: trim padding to the longest item in this shard
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        # 1. Encoder (may return (out, intermediate_outs) for interCTC)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()
        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss: blend main and intermediate CTC
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc
        # 2b. Attention decoder branch (with hotword biasing; see _calc_att_loss)
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def _sample_hot_word(self, ys_pad, ys_pad_lens):
        """Sample hotword phrases from the batch's own transcripts and encode them.

        Random spans of each reference are cut out as "hotwords", embedded,
        run through the bias LSTM, and the LSTM output at each phrase's last
        token is gathered.  When ``target_buffer_length > 0``, these vectors
        are additionally accumulated in a bounded cross-step buffer so earlier
        batches' hotwords act as distractors.

        Returns a tensor repeated to (batch, num_hotwords, inner_dim) on
        ``ys_pad``'s device.

        NOTE(review): assumes ``ys_pad`` rows contain valid token ids for at
        least ``ys_pad_lens[i]`` positions — confirm against the caller.
        """
        # Entry 0 is a dummy "no hotword" phrase (token id 0).
        hw_list = [torch.Tensor([0]).long().to(ys_pad.device)]
        hw_lengths = [0]  # stores (len - 1): the index used to gather the last LSTM state
        for i, length in enumerate(ys_pad_lens):
            if length < 2:
                continue
            if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate:
                # sample double hotword from a long utterance
                _max_hw_length = min(self.max_hw_length, length // 2)
                # first hotword: starts in the first third
                start1 = random.randint(0, length // 3)
                end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1)
                hw_tokens1 = ys_pad[i][start1:end1 + 1]
                hw_lengths.append(len(hw_tokens1) - 1)
                hw_list.append(hw_tokens1)
                # second hotword: starts after the first one ends
                start2 = random.randint(end1 + 1, length - self.min_hw_length)
                end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1),
                                      min(length - 1, start2 + self.max_hw_length - 1))
                hw_tokens2 = ys_pad[i][start2:end2 + 1]
                hw_lengths.append(len(hw_tokens2) - 1)
                hw_list.append(hw_tokens2)
                continue
            if random.random() < self.sample_rate:
                if length == 2:
                    # too short to choose a span: take the whole pair
                    hw_tokens = ys_pad[i][:2]
                    hw_lengths.append(1)
                    hw_list.append(hw_tokens)
                else:
                    start = random.randint(0, length - self.min_hw_length)
                    end = random.randint(min(length - 1, start + self.min_hw_length - 1),
                                         min(length - 1, start + self.max_hw_length - 1)) + 1
                    hw_tokens = ys_pad[i][start:end]
                    hw_lengths.append(len(hw_tokens) - 1)
                    hw_list.append(hw_tokens)
        # padding to a rectangular batch of hotword token ids
        hw_list_pad = pad_list(hw_list, 0)
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hw_list_pad)
        else:
            hw_embed = self.bias_embed(hw_list_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        _ind = np.arange(0, len(hw_list)).tolist()
        # Gather, per hotword, the LSTM output at its final valid token.
        selected = hw_embed[_ind, hw_lengths]
        # update self.hotword_buffer, throw a part if oversize
        if self.target_buffer_length > 0:
            _b = selected.shape[0]
            if self.hotword_buffer is None:
                self.hotword_buffer = selected
                self.length_record.append(selected.shape[0])
                self.current_buffer_length = _b
            elif self.current_buffer_length + _b < self.target_buffer_length:
                # Buffer has room: append (detached) and use the whole buffer.
                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
                self.current_buffer_length += _b
                selected = self.hotword_buffer
            else:
                # Buffer full: append, then keep only a random-sized tail.
                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
                random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
                self.hotword_buffer = self.hotword_buffer[-1 * random_throw:]
                selected = self.hotword_buffer
                self.current_buffer_length = selected.shape[0]
        # Same hotword set is shared by every utterance in the batch.
        return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
  1812. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
  1813. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  1814. ys_pad = ys_pad * tgt_mask[:, :, 0]
  1815. if self.share_embedding:
  1816. ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
  1817. else:
  1818. ys_pad_embed = self.decoder.embed(ys_pad)
  1819. with torch.no_grad():
  1820. decoder_outs = self.decoder(
  1821. encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
  1822. )
  1823. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  1824. pred_tokens = decoder_out.argmax(-1)
  1825. nonpad_positions = ys_pad.ne(self.ignore_id)
  1826. seq_lens = (nonpad_positions).sum(1)
  1827. same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
  1828. input_mask = torch.ones_like(nonpad_positions)
  1829. bsz, seq_len = ys_pad.size()
  1830. for li in range(bsz):
  1831. target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
  1832. if target_num > 0:
  1833. input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
  1834. input_mask = input_mask.eq(1)
  1835. input_mask = input_mask.masked_fill(~nonpad_positions, False)
  1836. input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
  1837. sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
  1838. input_mask_expand_dim, 0)
  1839. return sematic_embeds * tgt_mask, decoder_out * tgt_mask
  1840. def _calc_att_loss(
  1841. self,
  1842. encoder_out: torch.Tensor,
  1843. encoder_out_lens: torch.Tensor,
  1844. ys_pad: torch.Tensor,
  1845. ys_pad_lens: torch.Tensor,
  1846. ):
  1847. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  1848. encoder_out.device)
  1849. if self.predictor_bias == 1:
  1850. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  1851. ys_pad_lens = ys_pad_lens + self.predictor_bias
  1852. pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad,
  1853. encoder_out_mask,
  1854. ignore_id=self.ignore_id)
  1855. # sample hot word
  1856. contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)
  1857. # 0. sampler
  1858. decoder_out_1st = None
  1859. if self.sampling_ratio > 0.0:
  1860. if self.step_cur < 2:
  1861. logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  1862. sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
  1863. pre_acoustic_embeds, contextual_info)
  1864. else:
  1865. if self.step_cur < 2:
  1866. logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  1867. sematic_embeds = pre_acoustic_embeds
  1868. # 1. Forward decoder
  1869. decoder_outs = self.decoder(
  1870. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
  1871. )
  1872. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  1873. if decoder_out_1st is None:
  1874. decoder_out_1st = decoder_out
  1875. # 2. Compute attention loss
  1876. loss_att = self.criterion_att(decoder_out, ys_pad)
  1877. acc_att = th_accuracy(
  1878. decoder_out_1st.view(-1, self.vocab_size),
  1879. ys_pad,
  1880. ignore_label=self.ignore_id,
  1881. )
  1882. loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  1883. # Compute cer/wer using attention-decoder
  1884. if self.training or self.error_calculator is None:
  1885. cer_att, wer_att = None, None
  1886. else:
  1887. ys_hat = decoder_out_1st.argmax(dim=-1)
  1888. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  1889. return loss_att, acc_att, cer_att, wer_att, loss_pre
  1890. def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
  1891. if hw_list is None:
  1892. # default hotword list
  1893. hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)] # empty hotword list
  1894. hw_list_pad = pad_list(hw_list, 0)
  1895. if self.use_decoder_embedding:
  1896. hw_embed = self.decoder.embed(hw_list_pad)
  1897. else:
  1898. hw_embed = self.bias_embed(hw_list_pad)
  1899. _, (h_n, _) = self.bias_encoder(hw_embed)
  1900. contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
  1901. else:
  1902. hw_lengths = [len(i) for i in hw_list]
  1903. hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
  1904. if self.use_decoder_embedding:
  1905. hw_embed = self.decoder.embed(hw_list_pad)
  1906. else:
  1907. hw_embed = self.bias_embed(hw_list_pad)
  1908. hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
  1909. enforce_sorted=False)
  1910. _, (h_n, _) = self.bias_encoder(hw_embed)
  1911. # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True)
  1912. contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
  1913. decoder_outs = self.decoder(
  1914. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
  1915. )
  1916. decoder_out = decoder_outs[0]
  1917. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  1918. return decoder_out, ys_pad_lens
  1919. def gen_clas_tf2torch_map_dict(self):
  1920. tensor_name_prefix_torch = "bias_encoder"
  1921. tensor_name_prefix_tf = "seq2seq/clas_charrnn"
  1922. tensor_name_prefix_torch_emb = "bias_embed"
  1923. tensor_name_prefix_tf_emb = "seq2seq"
  1924. map_dict_local = {
  1925. # in lstm
  1926. "{}.weight_ih_l0".format(tensor_name_prefix_torch):
  1927. {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
  1928. "squeeze": None,
  1929. "transpose": (1, 0),
  1930. "slice": (0, 512),
  1931. "unit_k": 512,
  1932. }, # (1024, 2048),(2048,512)
  1933. "{}.weight_hh_l0".format(tensor_name_prefix_torch):
  1934. {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
  1935. "squeeze": None,
  1936. "transpose": (1, 0),
  1937. "slice": (512, 1024),
  1938. "unit_k": 512,
  1939. }, # (1024, 2048),(2048,512)
  1940. "{}.bias_ih_l0".format(tensor_name_prefix_torch):
  1941. {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
  1942. "squeeze": None,
  1943. "transpose": None,
  1944. "scale": 0.5,
  1945. "unit_b": 512,
  1946. }, # (2048,),(2048,)
  1947. "{}.bias_hh_l0".format(tensor_name_prefix_torch):
  1948. {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
  1949. "squeeze": None,
  1950. "transpose": None,
  1951. "scale": 0.5,
  1952. "unit_b": 512,
  1953. }, # (2048,),(2048,)
  1954. # in embed
  1955. "{}.weight".format(tensor_name_prefix_torch_emb):
  1956. {"name": "{}/contextual_encoder/w_char_embs".format(tensor_name_prefix_tf_emb),
  1957. "squeeze": None,
  1958. "transpose": None,
  1959. }, # (4235,256),(4235,256)
  1960. }
  1961. return map_dict_local
    def clas_convert_tf2torch(self,
                              var_dict_tf,
                              var_dict_torch):
        """Convert TF CLAS bias-encoder checkpoint tensors to torch parameters.

        For every torch parameter named ``bias_encoder.*`` or ``bias_embed.*``,
        looks up its TF counterpart via :meth:`gen_clas_tf2torch_map_dict`,
        applies the configured transforms (gate re-ordering, slice, scale,
        squeeze, transpose), and returns a dict of ready-to-load float32
        CPU tensors keyed by the torch parameter names.

        Args:
            var_dict_tf: mapping of TF tensor names to numpy arrays.
            var_dict_torch: mapping of torch parameter names to tensors
                (used for name enumeration and size checks).
        """
        map_dict = self.gen_clas_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == "bias_encoder":
                name_q = name
                if name_q in map_dict.keys():
                    name_v = map_dict[name_q]["name"]
                    name_tf = name_v
                    data_tf = var_dict_tf[name_tf]
                    # Re-order the four fused LSTM gate blocks along columns.
                    # NOTE(review): slices are named i/f/o/g and emitted as
                    # [i, o, f, g]; whether these labels match TF's actual gate
                    # layout cannot be verified from this file — the reordering
                    # is what maps the TF layout onto torch's expected one.
                    if map_dict[name_q].get("unit_k") is not None:
                        dim = map_dict[name_q]["unit_k"]
                        i = data_tf[:, 0:dim].copy()
                        f = data_tf[:, dim:2 * dim].copy()
                        o = data_tf[:, 2 * dim:3 * dim].copy()
                        g = data_tf[:, 3 * dim:4 * dim].copy()
                        data_tf = np.concatenate([i, o, f, g], axis=1)
                    # Same gate re-ordering for the 1-D bias vector.
                    if map_dict[name_q].get("unit_b") is not None:
                        dim = map_dict[name_q]["unit_b"]
                        i = data_tf[0:dim].copy()
                        f = data_tf[dim:2 * dim].copy()
                        o = data_tf[2 * dim:3 * dim].copy()
                        g = data_tf[3 * dim:4 * dim].copy()
                        data_tf = np.concatenate([i, o, f, g], axis=0)
                    if map_dict[name_q]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                    # Row-slice the fused kernel into ih vs hh parts.
                    if map_dict[name_q].get("slice") is not None:
                        data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]]
                    # Scale (e.g. 0.5: one TF bias split across two torch biases).
                    if map_dict[name_q].get("scale") is not None:
                        data_tf = data_tf * map_dict[name_q]["scale"]
                    if map_dict[name_q]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                    var_dict_torch[
                                                                                                        name].size(),
                                                                                                    data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                      var_dict_tf[name_tf].shape))
            elif names[0] == "bias_embed":
                # Embedding table: only squeeze/transpose apply, no gate logic.
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                var_dict_torch[
                                                                                                    name].size(),
                                                                                                data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                  var_dict_tf[name_tf].shape))
        return var_dict_torch_update