e2e_asr_paraformer.py 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764
  1. import logging
  2. from contextlib import contextmanager
  3. from distutils.version import LooseVersion
  4. from typing import Dict
  5. from typing import List
  6. from typing import Optional
  7. from typing import Tuple
  8. from typing import Union
  9. import torch
  10. import random
  11. import numpy as np
  12. from typeguard import check_argument_types
  13. from funasr.layers.abs_normalize import AbsNormalize
  14. from funasr.losses.label_smoothing_loss import (
  15. LabelSmoothingLoss, # noqa: H301
  16. )
  17. from funasr.models.ctc import CTC
  18. from funasr.models.decoder.abs_decoder import AbsDecoder
  19. from funasr.models.e2e_asr_common import ErrorCalculator
  20. from funasr.models.encoder.abs_encoder import AbsEncoder
  21. from funasr.models.frontend.abs_frontend import AbsFrontend
  22. from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
  23. from funasr.models.predictor.cif import mae_loss
  24. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  25. from funasr.models.specaug.abs_specaug import AbsSpecAug
  26. from funasr.modules.add_sos_eos import add_sos_eos
  27. from funasr.modules.nets_utils import make_pad_mask, pad_list
  28. from funasr.modules.nets_utils import th_accuracy
  29. from funasr.torch_utils.device_funcs import force_gatherable
  30. from funasr.train.abs_espnet_model import AbsESPnetModel
  31. from funasr.models.predictor.cif import CifPredictorV3
  32. if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
  33. from torch.cuda.amp import autocast
  34. else:
  35. # Nothing to do if torch<1.6.0
  36. @contextmanager
  37. def autocast(enabled=True):
  38. yield
  39. class Paraformer(AbsESPnetModel):
  40. """
  41. Author: Speech Lab of DAMO Academy, Alibaba Group
  42. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  43. https://arxiv.org/abs/2206.08317
  44. """
  45. def __init__(
  46. self,
  47. vocab_size: int,
  48. token_list: Union[Tuple[str, ...], List[str]],
  49. frontend: Optional[AbsFrontend],
  50. specaug: Optional[AbsSpecAug],
  51. normalize: Optional[AbsNormalize],
  52. preencoder: Optional[AbsPreEncoder],
  53. encoder: AbsEncoder,
  54. postencoder: Optional[AbsPostEncoder],
  55. decoder: AbsDecoder,
  56. ctc: CTC,
  57. ctc_weight: float = 0.5,
  58. interctc_weight: float = 0.0,
  59. ignore_id: int = -1,
  60. blank_id: int = 0,
  61. sos: int = 1,
  62. eos: int = 2,
  63. lsm_weight: float = 0.0,
  64. length_normalized_loss: bool = False,
  65. report_cer: bool = True,
  66. report_wer: bool = True,
  67. sym_space: str = "<space>",
  68. sym_blank: str = "<blank>",
  69. extract_feats_in_collect_stats: bool = True,
  70. predictor=None,
  71. predictor_weight: float = 0.0,
  72. predictor_bias: int = 0,
  73. sampling_ratio: float = 0.2,
  74. share_embedding: bool = False,
  75. ):
  76. assert check_argument_types()
  77. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  78. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  79. super().__init__()
  80. # note that eos is the same as sos (equivalent ID)
  81. self.blank_id = blank_id
  82. self.sos = vocab_size - 1 if sos is None else sos
  83. self.eos = vocab_size - 1 if eos is None else eos
  84. self.vocab_size = vocab_size
  85. self.ignore_id = ignore_id
  86. self.ctc_weight = ctc_weight
  87. self.interctc_weight = interctc_weight
  88. self.token_list = token_list.copy()
  89. self.frontend = frontend
  90. self.specaug = specaug
  91. self.normalize = normalize
  92. self.preencoder = preencoder
  93. self.postencoder = postencoder
  94. self.encoder = encoder
  95. if not hasattr(self.encoder, "interctc_use_conditioning"):
  96. self.encoder.interctc_use_conditioning = False
  97. if self.encoder.interctc_use_conditioning:
  98. self.encoder.conditioning_layer = torch.nn.Linear(
  99. vocab_size, self.encoder.output_size()
  100. )
  101. self.error_calculator = None
  102. if ctc_weight == 1.0:
  103. self.decoder = None
  104. else:
  105. self.decoder = decoder
  106. self.criterion_att = LabelSmoothingLoss(
  107. size=vocab_size,
  108. padding_idx=ignore_id,
  109. smoothing=lsm_weight,
  110. normalize_length=length_normalized_loss,
  111. )
  112. if report_cer or report_wer:
  113. self.error_calculator = ErrorCalculator(
  114. token_list, sym_space, sym_blank, report_cer, report_wer
  115. )
  116. if ctc_weight == 0.0:
  117. self.ctc = None
  118. else:
  119. self.ctc = ctc
  120. self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
  121. self.predictor = predictor
  122. self.predictor_weight = predictor_weight
  123. self.predictor_bias = predictor_bias
  124. self.sampling_ratio = sampling_ratio
  125. self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
  126. self.step_cur = 0
  127. self.share_embedding = share_embedding
  128. if self.share_embedding:
  129. self.decoder.embed = None
  130. def forward(
  131. self,
  132. speech: torch.Tensor,
  133. speech_lengths: torch.Tensor,
  134. text: torch.Tensor,
  135. text_lengths: torch.Tensor,
  136. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  137. """Frontend + Encoder + Decoder + Calc loss
  138. Args:
  139. speech: (Batch, Length, ...)
  140. speech_lengths: (Batch, )
  141. text: (Batch, Length)
  142. text_lengths: (Batch,)
  143. """
  144. assert text_lengths.dim() == 1, text_lengths.shape
  145. # Check that batch_size is unified
  146. assert (
  147. speech.shape[0]
  148. == speech_lengths.shape[0]
  149. == text.shape[0]
  150. == text_lengths.shape[0]
  151. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  152. batch_size = speech.shape[0]
  153. self.step_cur += 1
  154. # for data-parallel
  155. text = text[:, : text_lengths.max()]
  156. speech = speech[:, :speech_lengths.max()]
  157. # 1. Encoder
  158. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  159. intermediate_outs = None
  160. if isinstance(encoder_out, tuple):
  161. intermediate_outs = encoder_out[1]
  162. encoder_out = encoder_out[0]
  163. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  164. loss_ctc, cer_ctc = None, None
  165. loss_pre = None
  166. stats = dict()
  167. # 1. CTC branch
  168. if self.ctc_weight != 0.0:
  169. loss_ctc, cer_ctc = self._calc_ctc_loss(
  170. encoder_out, encoder_out_lens, text, text_lengths
  171. )
  172. # Collect CTC branch stats
  173. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  174. stats["cer_ctc"] = cer_ctc
  175. # Intermediate CTC (optional)
  176. loss_interctc = 0.0
  177. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  178. for layer_idx, intermediate_out in intermediate_outs:
  179. # we assume intermediate_out has the same length & padding
  180. # as those of encoder_out
  181. loss_ic, cer_ic = self._calc_ctc_loss(
  182. intermediate_out, encoder_out_lens, text, text_lengths
  183. )
  184. loss_interctc = loss_interctc + loss_ic
  185. # Collect Intermedaite CTC stats
  186. stats["loss_interctc_layer{}".format(layer_idx)] = (
  187. loss_ic.detach() if loss_ic is not None else None
  188. )
  189. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  190. loss_interctc = loss_interctc / len(intermediate_outs)
  191. # calculate whole encoder loss
  192. loss_ctc = (
  193. 1 - self.interctc_weight
  194. ) * loss_ctc + self.interctc_weight * loss_interctc
  195. # 2b. Attention decoder branch
  196. if self.ctc_weight != 1.0:
  197. loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
  198. encoder_out, encoder_out_lens, text, text_lengths
  199. )
  200. # 3. CTC-Att loss definition
  201. if self.ctc_weight == 0.0:
  202. loss = loss_att + loss_pre * self.predictor_weight
  203. elif self.ctc_weight == 1.0:
  204. loss = loss_ctc
  205. else:
  206. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
  207. # Collect Attn branch stats
  208. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  209. stats["acc"] = acc_att
  210. stats["cer"] = cer_att
  211. stats["wer"] = wer_att
  212. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  213. stats["loss"] = torch.clone(loss.detach())
  214. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  215. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  216. return loss, stats, weight
  217. def collect_feats(
  218. self,
  219. speech: torch.Tensor,
  220. speech_lengths: torch.Tensor,
  221. text: torch.Tensor,
  222. text_lengths: torch.Tensor,
  223. ) -> Dict[str, torch.Tensor]:
  224. if self.extract_feats_in_collect_stats:
  225. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  226. else:
  227. # Generate dummy stats if extract_feats_in_collect_stats is False
  228. logging.warning(
  229. "Generating dummy stats for feats and feats_lengths, "
  230. "because encoder_conf.extract_feats_in_collect_stats is "
  231. f"{self.extract_feats_in_collect_stats}"
  232. )
  233. feats, feats_lengths = speech, speech_lengths
  234. return {"feats": feats, "feats_lengths": feats_lengths}
  235. def encode(
  236. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  237. ) -> Tuple[torch.Tensor, torch.Tensor]:
  238. """Frontend + Encoder. Note that this method is used by asr_inference.py
  239. Args:
  240. speech: (Batch, Length, ...)
  241. speech_lengths: (Batch, )
  242. """
  243. with autocast(False):
  244. # 1. Extract feats
  245. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  246. # 2. Data augmentation
  247. if self.specaug is not None and self.training:
  248. feats, feats_lengths = self.specaug(feats, feats_lengths)
  249. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  250. if self.normalize is not None:
  251. feats, feats_lengths = self.normalize(feats, feats_lengths)
  252. # Pre-encoder, e.g. used for raw input data
  253. if self.preencoder is not None:
  254. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  255. # 4. Forward encoder
  256. # feats: (Batch, Length, Dim)
  257. # -> encoder_out: (Batch, Length2, Dim2)
  258. if self.encoder.interctc_use_conditioning:
  259. encoder_out, encoder_out_lens, _ = self.encoder(
  260. feats, feats_lengths, ctc=self.ctc
  261. )
  262. else:
  263. encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
  264. intermediate_outs = None
  265. if isinstance(encoder_out, tuple):
  266. intermediate_outs = encoder_out[1]
  267. encoder_out = encoder_out[0]
  268. # Post-encoder, e.g. NLU
  269. if self.postencoder is not None:
  270. encoder_out, encoder_out_lens = self.postencoder(
  271. encoder_out, encoder_out_lens
  272. )
  273. assert encoder_out.size(0) == speech.size(0), (
  274. encoder_out.size(),
  275. speech.size(0),
  276. )
  277. assert encoder_out.size(1) <= encoder_out_lens.max(), (
  278. encoder_out.size(),
  279. encoder_out_lens.max(),
  280. )
  281. if intermediate_outs is not None:
  282. return (encoder_out, intermediate_outs), encoder_out_lens
  283. return encoder_out, encoder_out_lens
  284. def calc_predictor(self, encoder_out, encoder_out_lens):
  285. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  286. encoder_out.device)
  287. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
  288. ignore_id=self.ignore_id)
  289. return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
  290. def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
  291. decoder_outs = self.decoder(
  292. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
  293. )
  294. decoder_out = decoder_outs[0]
  295. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  296. return decoder_out, ys_pad_lens
  297. def _extract_feats(
  298. self, speech: torch.Tensor, speech_lengths: torch.Tensor
  299. ) -> Tuple[torch.Tensor, torch.Tensor]:
  300. assert speech_lengths.dim() == 1, speech_lengths.shape
  301. # for data-parallel
  302. speech = speech[:, : speech_lengths.max()]
  303. if self.frontend is not None:
  304. # Frontend
  305. # e.g. STFT and Feature extract
  306. # data_loader may send time-domain signal in this case
  307. # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
  308. feats, feats_lengths = self.frontend(speech, speech_lengths)
  309. else:
  310. # No frontend and no feature extract
  311. feats, feats_lengths = speech, speech_lengths
  312. return feats, feats_lengths
  313. def nll(
  314. self,
  315. encoder_out: torch.Tensor,
  316. encoder_out_lens: torch.Tensor,
  317. ys_pad: torch.Tensor,
  318. ys_pad_lens: torch.Tensor,
  319. ) -> torch.Tensor:
  320. """Compute negative log likelihood(nll) from transformer-decoder
  321. Normally, this function is called in batchify_nll.
  322. Args:
  323. encoder_out: (Batch, Length, Dim)
  324. encoder_out_lens: (Batch,)
  325. ys_pad: (Batch, Length)
  326. ys_pad_lens: (Batch,)
  327. """
  328. ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  329. ys_in_lens = ys_pad_lens + 1
  330. # 1. Forward decoder
  331. decoder_out, _ = self.decoder(
  332. encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
  333. ) # [batch, seqlen, dim]
  334. batch_size = decoder_out.size(0)
  335. decoder_num_class = decoder_out.size(2)
  336. # nll: negative log-likelihood
  337. nll = torch.nn.functional.cross_entropy(
  338. decoder_out.view(-1, decoder_num_class),
  339. ys_out_pad.view(-1),
  340. ignore_index=self.ignore_id,
  341. reduction="none",
  342. )
  343. nll = nll.view(batch_size, -1)
  344. nll = nll.sum(dim=1)
  345. assert nll.size(0) == batch_size
  346. return nll
  347. def batchify_nll(
  348. self,
  349. encoder_out: torch.Tensor,
  350. encoder_out_lens: torch.Tensor,
  351. ys_pad: torch.Tensor,
  352. ys_pad_lens: torch.Tensor,
  353. batch_size: int = 100,
  354. ):
  355. """Compute negative log likelihood(nll) from transformer-decoder
  356. To avoid OOM, this fuction seperate the input into batches.
  357. Then call nll for each batch and combine and return results.
  358. Args:
  359. encoder_out: (Batch, Length, Dim)
  360. encoder_out_lens: (Batch,)
  361. ys_pad: (Batch, Length)
  362. ys_pad_lens: (Batch,)
  363. batch_size: int, samples each batch contain when computing nll,
  364. you may change this to avoid OOM or increase
  365. GPU memory usage
  366. """
  367. total_num = encoder_out.size(0)
  368. if total_num <= batch_size:
  369. nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  370. else:
  371. nll = []
  372. start_idx = 0
  373. while True:
  374. end_idx = min(start_idx + batch_size, total_num)
  375. batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
  376. batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
  377. batch_ys_pad = ys_pad[start_idx:end_idx, :]
  378. batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
  379. batch_nll = self.nll(
  380. batch_encoder_out,
  381. batch_encoder_out_lens,
  382. batch_ys_pad,
  383. batch_ys_pad_lens,
  384. )
  385. nll.append(batch_nll)
  386. start_idx = end_idx
  387. if start_idx == total_num:
  388. break
  389. nll = torch.cat(nll)
  390. assert nll.size(0) == total_num
  391. return nll
  392. def _calc_att_loss(
  393. self,
  394. encoder_out: torch.Tensor,
  395. encoder_out_lens: torch.Tensor,
  396. ys_pad: torch.Tensor,
  397. ys_pad_lens: torch.Tensor,
  398. ):
  399. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  400. encoder_out.device)
  401. if self.predictor_bias == 1:
  402. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  403. ys_pad_lens = ys_pad_lens + self.predictor_bias
  404. pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
  405. ignore_id=self.ignore_id)
  406. # 0. sampler
  407. decoder_out_1st = None
  408. if self.sampling_ratio > 0.0:
  409. if self.step_cur < 2:
  410. logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  411. sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
  412. pre_acoustic_embeds)
  413. else:
  414. if self.step_cur < 2:
  415. logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  416. sematic_embeds = pre_acoustic_embeds
  417. # 1. Forward decoder
  418. decoder_outs = self.decoder(
  419. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
  420. )
  421. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  422. if decoder_out_1st is None:
  423. decoder_out_1st = decoder_out
  424. # 2. Compute attention loss
  425. loss_att = self.criterion_att(decoder_out, ys_pad)
  426. acc_att = th_accuracy(
  427. decoder_out_1st.view(-1, self.vocab_size),
  428. ys_pad,
  429. ignore_label=self.ignore_id,
  430. )
  431. loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  432. # Compute cer/wer using attention-decoder
  433. if self.training or self.error_calculator is None:
  434. cer_att, wer_att = None, None
  435. else:
  436. ys_hat = decoder_out_1st.argmax(dim=-1)
  437. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  438. return loss_att, acc_att, cer_att, wer_att, loss_pre
  439. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
  440. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  441. ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
  442. if self.share_embedding:
  443. ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
  444. else:
  445. ys_pad_embed = self.decoder.embed(ys_pad_masked)
  446. with torch.no_grad():
  447. decoder_outs = self.decoder(
  448. encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
  449. )
  450. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  451. pred_tokens = decoder_out.argmax(-1)
  452. nonpad_positions = ys_pad.ne(self.ignore_id)
  453. seq_lens = (nonpad_positions).sum(1)
  454. same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
  455. input_mask = torch.ones_like(nonpad_positions)
  456. bsz, seq_len = ys_pad.size()
  457. for li in range(bsz):
  458. target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
  459. if target_num > 0:
  460. input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
  461. input_mask = input_mask.eq(1)
  462. input_mask = input_mask.masked_fill(~nonpad_positions, False)
  463. input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
  464. sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
  465. input_mask_expand_dim, 0)
  466. return sematic_embeds * tgt_mask, decoder_out * tgt_mask
  467. def _calc_ctc_loss(
  468. self,
  469. encoder_out: torch.Tensor,
  470. encoder_out_lens: torch.Tensor,
  471. ys_pad: torch.Tensor,
  472. ys_pad_lens: torch.Tensor,
  473. ):
  474. # Calc CTC loss
  475. loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
  476. # Calc CER using CTC
  477. cer_ctc = None
  478. if not self.training and self.error_calculator is not None:
  479. ys_hat = self.ctc.argmax(encoder_out).data
  480. cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
  481. return loss_ctc, cer_ctc
  482. class ParaformerOnline(Paraformer):
  483. """
  484. Author: Speech Lab, Alibaba Group, China
  485. Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
  486. https://arxiv.org/abs/2206.08317
  487. """
  488. def __init__(
  489. self, *args, **kwargs,
  490. ):
  491. super().__init__(*args, **kwargs)
  492. def forward(
  493. self,
  494. speech: torch.Tensor,
  495. speech_lengths: torch.Tensor,
  496. text: torch.Tensor,
  497. text_lengths: torch.Tensor,
  498. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  499. """Frontend + Encoder + Decoder + Calc loss
  500. Args:
  501. speech: (Batch, Length, ...)
  502. speech_lengths: (Batch, )
  503. text: (Batch, Length)
  504. text_lengths: (Batch,)
  505. """
  506. assert text_lengths.dim() == 1, text_lengths.shape
  507. # Check that batch_size is unified
  508. assert (
  509. speech.shape[0]
  510. == speech_lengths.shape[0]
  511. == text.shape[0]
  512. == text_lengths.shape[0]
  513. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  514. batch_size = speech.shape[0]
  515. self.step_cur += 1
  516. # for data-parallel
  517. text = text[:, : text_lengths.max()]
  518. speech = speech[:, :speech_lengths.max()]
  519. # 1. Encoder
  520. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  521. intermediate_outs = None
  522. if isinstance(encoder_out, tuple):
  523. intermediate_outs = encoder_out[1]
  524. encoder_out = encoder_out[0]
  525. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  526. loss_ctc, cer_ctc = None, None
  527. loss_pre = None
  528. stats = dict()
  529. # 1. CTC branch
  530. if self.ctc_weight != 0.0:
  531. loss_ctc, cer_ctc = self._calc_ctc_loss(
  532. encoder_out, encoder_out_lens, text, text_lengths
  533. )
  534. # Collect CTC branch stats
  535. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  536. stats["cer_ctc"] = cer_ctc
  537. # Intermediate CTC (optional)
  538. loss_interctc = 0.0
  539. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  540. for layer_idx, intermediate_out in intermediate_outs:
  541. # we assume intermediate_out has the same length & padding
  542. # as those of encoder_out
  543. loss_ic, cer_ic = self._calc_ctc_loss(
  544. intermediate_out, encoder_out_lens, text, text_lengths
  545. )
  546. loss_interctc = loss_interctc + loss_ic
  547. # Collect Intermedaite CTC stats
  548. stats["loss_interctc_layer{}".format(layer_idx)] = (
  549. loss_ic.detach() if loss_ic is not None else None
  550. )
  551. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  552. loss_interctc = loss_interctc / len(intermediate_outs)
  553. # calculate whole encoder loss
  554. loss_ctc = (
  555. 1 - self.interctc_weight
  556. ) * loss_ctc + self.interctc_weight * loss_interctc
  557. # 2b. Attention decoder branch
  558. if self.ctc_weight != 1.0:
  559. loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
  560. encoder_out, encoder_out_lens, text, text_lengths
  561. )
  562. # 3. CTC-Att loss definition
  563. if self.ctc_weight == 0.0:
  564. loss = loss_att + loss_pre * self.predictor_weight
  565. elif self.ctc_weight == 1.0:
  566. loss = loss_ctc
  567. else:
  568. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
  569. # Collect Attn branch stats
  570. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  571. stats["acc"] = acc_att
  572. stats["cer"] = cer_att
  573. stats["wer"] = wer_att
  574. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  575. stats["loss"] = torch.clone(loss.detach())
  576. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  577. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  578. return loss, stats, weight
  579. def encode_chunk(
  580. self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
  581. ) -> Tuple[torch.Tensor, torch.Tensor]:
  582. """Frontend + Encoder. Note that this method is used by asr_inference.py
  583. Args:
  584. speech: (Batch, Length, ...)
  585. speech_lengths: (Batch, )
  586. """
  587. with autocast(False):
  588. # 1. Extract feats
  589. feats, feats_lengths = self._extract_feats(speech, speech_lengths)
  590. # 2. Data augmentation
  591. if self.specaug is not None and self.training:
  592. feats, feats_lengths = self.specaug(feats, feats_lengths)
  593. # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
  594. if self.normalize is not None:
  595. feats, feats_lengths = self.normalize(feats, feats_lengths)
  596. # Pre-encoder, e.g. used for raw input data
  597. if self.preencoder is not None:
  598. feats, feats_lengths = self.preencoder(feats, feats_lengths)
  599. # 4. Forward encoder
  600. # feats: (Batch, Length, Dim)
  601. # -> encoder_out: (Batch, Length2, Dim2)
  602. if self.encoder.interctc_use_conditioning:
  603. encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
  604. feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
  605. )
  606. else:
  607. encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
  608. intermediate_outs = None
  609. if isinstance(encoder_out, tuple):
  610. intermediate_outs = encoder_out[1]
  611. encoder_out = encoder_out[0]
  612. # Post-encoder, e.g. NLU
  613. if self.postencoder is not None:
  614. encoder_out, encoder_out_lens = self.postencoder(
  615. encoder_out, encoder_out_lens
  616. )
  617. if intermediate_outs is not None:
  618. return (encoder_out, intermediate_outs), encoder_out_lens
  619. return encoder_out, torch.tensor([encoder_out.size(1)])
  620. def calc_predictor_chunk(self, encoder_out, cache=None):
  621. pre_acoustic_embeds, pre_token_length = \
  622. self.predictor.forward_chunk(encoder_out, cache["encoder"])
  623. return pre_acoustic_embeds, pre_token_length
  624. def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
  625. decoder_outs = self.decoder.forward_chunk(
  626. encoder_out, sematic_embeds, cache["decoder"]
  627. )
  628. decoder_out = decoder_outs
  629. decoder_out = torch.log_softmax(decoder_out, dim=-1)
  630. return decoder_out
  631. class ParaformerBert(Paraformer):
  632. """
  633. Author: Speech Lab of DAMO Academy, Alibaba Group
  634. Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
  635. """
  636. def __init__(
  637. self,
  638. vocab_size: int,
  639. token_list: Union[Tuple[str, ...], List[str]],
  640. frontend: Optional[AbsFrontend],
  641. specaug: Optional[AbsSpecAug],
  642. normalize: Optional[AbsNormalize],
  643. preencoder: Optional[AbsPreEncoder],
  644. encoder: AbsEncoder,
  645. postencoder: Optional[AbsPostEncoder],
  646. decoder: AbsDecoder,
  647. ctc: CTC,
  648. ctc_weight: float = 0.5,
  649. interctc_weight: float = 0.0,
  650. ignore_id: int = -1,
  651. blank_id: int = 0,
  652. sos: int = 1,
  653. eos: int = 2,
  654. lsm_weight: float = 0.0,
  655. length_normalized_loss: bool = False,
  656. report_cer: bool = True,
  657. report_wer: bool = True,
  658. sym_space: str = "<space>",
  659. sym_blank: str = "<blank>",
  660. extract_feats_in_collect_stats: bool = True,
  661. predictor=None,
  662. predictor_weight: float = 0.0,
  663. predictor_bias: int = 0,
  664. sampling_ratio: float = 0.2,
  665. embeds_id: int = 2,
  666. embeds_loss_weight: float = 0.0,
  667. embed_dims: int = 768,
  668. ):
  669. assert check_argument_types()
  670. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  671. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  672. super().__init__(
  673. vocab_size=vocab_size,
  674. token_list=token_list,
  675. frontend=frontend,
  676. specaug=specaug,
  677. normalize=normalize,
  678. preencoder=preencoder,
  679. encoder=encoder,
  680. postencoder=postencoder,
  681. decoder=decoder,
  682. ctc=ctc,
  683. ctc_weight=ctc_weight,
  684. interctc_weight=interctc_weight,
  685. ignore_id=ignore_id,
  686. blank_id=blank_id,
  687. sos=sos,
  688. eos=eos,
  689. lsm_weight=lsm_weight,
  690. length_normalized_loss=length_normalized_loss,
  691. report_cer=report_cer,
  692. report_wer=report_wer,
  693. sym_space=sym_space,
  694. sym_blank=sym_blank,
  695. extract_feats_in_collect_stats=extract_feats_in_collect_stats,
  696. predictor=predictor,
  697. predictor_weight=predictor_weight,
  698. predictor_bias=predictor_bias,
  699. sampling_ratio=sampling_ratio,
  700. )
  701. self.decoder.embeds_id = embeds_id
  702. decoder_attention_dim = self.decoder.attention_dim
  703. self.pro_nn = torch.nn.Linear(decoder_attention_dim, embed_dims)
  704. self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
  705. self.embeds_loss_weight = embeds_loss_weight
  706. self.length_normalized_loss = length_normalized_loss
  707. def _calc_embed_loss(self,
  708. ys_pad: torch.Tensor,
  709. ys_pad_lens: torch.Tensor,
  710. embed: torch.Tensor = None,
  711. embed_lengths: torch.Tensor = None,
  712. embeds_outputs: torch.Tensor = None,
  713. ):
  714. embeds_outputs = self.pro_nn(embeds_outputs)
  715. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  716. embeds_outputs *= tgt_mask # b x l x d
  717. embed *= tgt_mask # b x l x d
  718. cos_loss = 1.0 - self.cos(embeds_outputs, embed)
  719. cos_loss *= tgt_mask.squeeze(2)
  720. if self.length_normalized_loss:
  721. token_num_total = torch.sum(tgt_mask)
  722. else:
  723. token_num_total = tgt_mask.size()[0]
  724. cos_loss_total = torch.sum(cos_loss)
  725. cos_loss = cos_loss_total / token_num_total
  726. # print("cos_loss: {}".format(cos_loss))
  727. return cos_loss
  728. def _calc_att_loss(
  729. self,
  730. encoder_out: torch.Tensor,
  731. encoder_out_lens: torch.Tensor,
  732. ys_pad: torch.Tensor,
  733. ys_pad_lens: torch.Tensor,
  734. ):
  735. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  736. encoder_out.device)
  737. if self.predictor_bias == 1:
  738. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  739. ys_pad_lens = ys_pad_lens + self.predictor_bias
  740. pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
  741. ignore_id=self.ignore_id)
  742. # 0. sampler
  743. decoder_out_1st = None
  744. if self.sampling_ratio > 0.0:
  745. if self.step_cur < 2:
  746. logging.info(
  747. "enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  748. sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
  749. pre_acoustic_embeds)
  750. else:
  751. if self.step_cur < 2:
  752. logging.info(
  753. "disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  754. sematic_embeds = pre_acoustic_embeds
  755. # 1. Forward decoder
  756. decoder_outs = self.decoder(
  757. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
  758. )
  759. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  760. embeds_outputs = None
  761. if len(decoder_outs) > 2:
  762. embeds_outputs = decoder_outs[2]
  763. if decoder_out_1st is None:
  764. decoder_out_1st = decoder_out
  765. # 2. Compute attention loss
  766. loss_att = self.criterion_att(decoder_out, ys_pad)
  767. acc_att = th_accuracy(
  768. decoder_out_1st.view(-1, self.vocab_size),
  769. ys_pad,
  770. ignore_label=self.ignore_id,
  771. )
  772. loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  773. # Compute cer/wer using attention-decoder
  774. if self.training or self.error_calculator is None:
  775. cer_att, wer_att = None, None
  776. else:
  777. ys_hat = decoder_out_1st.argmax(dim=-1)
  778. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  779. return loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs
  780. def forward(
  781. self,
  782. speech: torch.Tensor,
  783. speech_lengths: torch.Tensor,
  784. text: torch.Tensor,
  785. text_lengths: torch.Tensor,
  786. embed: torch.Tensor = None,
  787. embed_lengths: torch.Tensor = None,
  788. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  789. """Frontend + Encoder + Decoder + Calc loss
  790. Args:
  791. speech: (Batch, Length, ...)
  792. speech_lengths: (Batch, )
  793. text: (Batch, Length)
  794. text_lengths: (Batch,)
  795. """
  796. assert text_lengths.dim() == 1, text_lengths.shape
  797. # Check that batch_size is unified
  798. assert (
  799. speech.shape[0]
  800. == speech_lengths.shape[0]
  801. == text.shape[0]
  802. == text_lengths.shape[0]
  803. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  804. batch_size = speech.shape[0]
  805. self.step_cur += 1
  806. # for data-parallel
  807. text = text[:, : text_lengths.max()]
  808. speech = speech[:, :speech_lengths.max(), :]
  809. if embed is not None:
  810. embed = embed[:, :embed_lengths.max(), :]
  811. # 1. Encoder
  812. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  813. intermediate_outs = None
  814. if isinstance(encoder_out, tuple):
  815. intermediate_outs = encoder_out[1]
  816. encoder_out = encoder_out[0]
  817. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  818. loss_ctc, cer_ctc = None, None
  819. loss_pre = 0.0
  820. cos_loss = 0.0
  821. stats = dict()
  822. # 1. CTC branch
  823. if self.ctc_weight != 0.0:
  824. loss_ctc, cer_ctc = self._calc_ctc_loss(
  825. encoder_out, encoder_out_lens, text, text_lengths
  826. )
  827. # Collect CTC branch stats
  828. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  829. stats["cer_ctc"] = cer_ctc
  830. # Intermediate CTC (optional)
  831. loss_interctc = 0.0
  832. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  833. for layer_idx, intermediate_out in intermediate_outs:
  834. # we assume intermediate_out has the same length & padding
  835. # as those of encoder_out
  836. loss_ic, cer_ic = self._calc_ctc_loss(
  837. intermediate_out, encoder_out_lens, text, text_lengths
  838. )
  839. loss_interctc = loss_interctc + loss_ic
  840. # Collect Intermedaite CTC stats
  841. stats["loss_interctc_layer{}".format(layer_idx)] = (
  842. loss_ic.detach() if loss_ic is not None else None
  843. )
  844. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  845. loss_interctc = loss_interctc / len(intermediate_outs)
  846. # calculate whole encoder loss
  847. loss_ctc = (
  848. 1 - self.interctc_weight
  849. ) * loss_ctc + self.interctc_weight * loss_interctc
  850. # 2b. Attention decoder branch
  851. if self.ctc_weight != 1.0:
  852. loss_ret = self._calc_att_loss(
  853. encoder_out, encoder_out_lens, text, text_lengths
  854. )
  855. loss_att, acc_att, cer_att, wer_att, loss_pre = loss_ret[0], loss_ret[1], loss_ret[2], loss_ret[3], \
  856. loss_ret[4]
  857. embeds_outputs = None
  858. if len(loss_ret) > 5:
  859. embeds_outputs = loss_ret[5]
  860. if embeds_outputs is not None:
  861. cos_loss = self._calc_embed_loss(text, text_lengths, embed, embed_lengths, embeds_outputs)
  862. # 3. CTC-Att loss definition
  863. if self.ctc_weight == 0.0:
  864. loss = loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
  865. elif self.ctc_weight == 1.0:
  866. loss = loss_ctc
  867. else:
  868. loss = self.ctc_weight * loss_ctc + (
  869. 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
  870. # Collect Attn branch stats
  871. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  872. stats["acc"] = acc_att
  873. stats["cer"] = cer_att
  874. stats["wer"] = wer_att
  875. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre > 0.0 else None
  876. stats["cos_loss"] = cos_loss.detach().cpu() if cos_loss > 0.0 else None
  877. stats["loss"] = torch.clone(loss.detach())
  878. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  879. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  880. return loss, stats, weight
  881. class BiCifParaformer(Paraformer):
  882. """
  883. Paraformer model with an extra cif predictor
  884. to conduct accurate timestamp prediction
  885. """
  886. def __init__(
  887. self,
  888. vocab_size: int,
  889. token_list: Union[Tuple[str, ...], List[str]],
  890. frontend: Optional[AbsFrontend],
  891. specaug: Optional[AbsSpecAug],
  892. normalize: Optional[AbsNormalize],
  893. preencoder: Optional[AbsPreEncoder],
  894. encoder: AbsEncoder,
  895. postencoder: Optional[AbsPostEncoder],
  896. decoder: AbsDecoder,
  897. ctc: CTC,
  898. ctc_weight: float = 0.5,
  899. interctc_weight: float = 0.0,
  900. ignore_id: int = -1,
  901. blank_id: int = 0,
  902. sos: int = 1,
  903. eos: int = 2,
  904. lsm_weight: float = 0.0,
  905. length_normalized_loss: bool = False,
  906. report_cer: bool = True,
  907. report_wer: bool = True,
  908. sym_space: str = "<space>",
  909. sym_blank: str = "<blank>",
  910. extract_feats_in_collect_stats: bool = True,
  911. predictor = None,
  912. predictor_weight: float = 0.0,
  913. predictor_bias: int = 0,
  914. sampling_ratio: float = 0.2,
  915. ):
  916. assert check_argument_types()
  917. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  918. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  919. super().__init__(
  920. vocab_size=vocab_size,
  921. token_list=token_list,
  922. frontend=frontend,
  923. specaug=specaug,
  924. normalize=normalize,
  925. preencoder=preencoder,
  926. encoder=encoder,
  927. postencoder=postencoder,
  928. decoder=decoder,
  929. ctc=ctc,
  930. ctc_weight=ctc_weight,
  931. interctc_weight=interctc_weight,
  932. ignore_id=ignore_id,
  933. blank_id=blank_id,
  934. sos=sos,
  935. eos=eos,
  936. lsm_weight=lsm_weight,
  937. length_normalized_loss=length_normalized_loss,
  938. report_cer=report_cer,
  939. report_wer=report_wer,
  940. sym_space=sym_space,
  941. sym_blank=sym_blank,
  942. extract_feats_in_collect_stats=extract_feats_in_collect_stats,
  943. predictor=predictor,
  944. predictor_weight=predictor_weight,
  945. predictor_bias=predictor_bias,
  946. sampling_ratio=sampling_ratio,
  947. )
  948. assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
  949. def _calc_pre2_loss(
  950. self,
  951. encoder_out: torch.Tensor,
  952. encoder_out_lens: torch.Tensor,
  953. ys_pad: torch.Tensor,
  954. ys_pad_lens: torch.Tensor,
  955. ):
  956. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  957. encoder_out.device)
  958. if self.predictor_bias == 1:
  959. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  960. ys_pad_lens = ys_pad_lens + self.predictor_bias
  961. _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
  962. # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  963. loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
  964. return loss_pre2
  965. def _calc_att_loss(
  966. self,
  967. encoder_out: torch.Tensor,
  968. encoder_out_lens: torch.Tensor,
  969. ys_pad: torch.Tensor,
  970. ys_pad_lens: torch.Tensor,
  971. ):
  972. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  973. encoder_out.device)
  974. if self.predictor_bias == 1:
  975. _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
  976. ys_pad_lens = ys_pad_lens + self.predictor_bias
  977. pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
  978. ignore_id=self.ignore_id)
  979. # 0. sampler
  980. decoder_out_1st = None
  981. if self.sampling_ratio > 0.0:
  982. if self.step_cur < 2:
  983. logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  984. sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
  985. pre_acoustic_embeds)
  986. else:
  987. if self.step_cur < 2:
  988. logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
  989. sematic_embeds = pre_acoustic_embeds
  990. # 1. Forward decoder
  991. decoder_outs = self.decoder(
  992. encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
  993. )
  994. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  995. if decoder_out_1st is None:
  996. decoder_out_1st = decoder_out
  997. # 2. Compute attention loss
  998. loss_att = self.criterion_att(decoder_out, ys_pad)
  999. acc_att = th_accuracy(
  1000. decoder_out_1st.view(-1, self.vocab_size),
  1001. ys_pad,
  1002. ignore_label=self.ignore_id,
  1003. )
  1004. loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
  1005. # Compute cer/wer using attention-decoder
  1006. if self.training or self.error_calculator is None:
  1007. cer_att, wer_att = None, None
  1008. else:
  1009. ys_hat = decoder_out_1st.argmax(dim=-1)
  1010. cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
  1011. return loss_att, acc_att, cer_att, wer_att, loss_pre
  1012. def calc_predictor(self, encoder_out, encoder_out_lens):
  1013. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  1014. encoder_out.device)
  1015. pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
  1016. ignore_id=self.ignore_id)
  1017. return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
  1018. def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
  1019. encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
  1020. encoder_out.device)
  1021. ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
  1022. encoder_out_mask,
  1023. token_num)
  1024. return ds_alphas, ds_cif_peak, us_alphas, us_peaks
  1025. def forward(
  1026. self,
  1027. speech: torch.Tensor,
  1028. speech_lengths: torch.Tensor,
  1029. text: torch.Tensor,
  1030. text_lengths: torch.Tensor,
  1031. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  1032. """Frontend + Encoder + Decoder + Calc loss
  1033. Args:
  1034. speech: (Batch, Length, ...)
  1035. speech_lengths: (Batch, )
  1036. text: (Batch, Length)
  1037. text_lengths: (Batch,)
  1038. """
  1039. assert text_lengths.dim() == 1, text_lengths.shape
  1040. # Check that batch_size is unified
  1041. assert (
  1042. speech.shape[0]
  1043. == speech_lengths.shape[0]
  1044. == text.shape[0]
  1045. == text_lengths.shape[0]
  1046. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  1047. batch_size = speech.shape[0]
  1048. self.step_cur += 1
  1049. # for data-parallel
  1050. text = text[:, : text_lengths.max()]
  1051. speech = speech[:, :speech_lengths.max()]
  1052. # 1. Encoder
  1053. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  1054. intermediate_outs = None
  1055. if isinstance(encoder_out, tuple):
  1056. intermediate_outs = encoder_out[1]
  1057. encoder_out = encoder_out[0]
  1058. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  1059. loss_ctc, cer_ctc = None, None
  1060. loss_pre = None
  1061. stats = dict()
  1062. # 1. CTC branch
  1063. if self.ctc_weight != 0.0:
  1064. loss_ctc, cer_ctc = self._calc_ctc_loss(
  1065. encoder_out, encoder_out_lens, text, text_lengths
  1066. )
  1067. # Collect CTC branch stats
  1068. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  1069. stats["cer_ctc"] = cer_ctc
  1070. # Intermediate CTC (optional)
  1071. loss_interctc = 0.0
  1072. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  1073. for layer_idx, intermediate_out in intermediate_outs:
  1074. # we assume intermediate_out has the same length & padding
  1075. # as those of encoder_out
  1076. loss_ic, cer_ic = self._calc_ctc_loss(
  1077. intermediate_out, encoder_out_lens, text, text_lengths
  1078. )
  1079. loss_interctc = loss_interctc + loss_ic
  1080. # Collect Intermedaite CTC stats
  1081. stats["loss_interctc_layer{}".format(layer_idx)] = (
  1082. loss_ic.detach() if loss_ic is not None else None
  1083. )
  1084. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  1085. loss_interctc = loss_interctc / len(intermediate_outs)
  1086. # calculate whole encoder loss
  1087. loss_ctc = (
  1088. 1 - self.interctc_weight
  1089. ) * loss_ctc + self.interctc_weight * loss_interctc
  1090. # 2b. Attention decoder branch
  1091. if self.ctc_weight != 1.0:
  1092. loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
  1093. encoder_out, encoder_out_lens, text, text_lengths
  1094. )
  1095. loss_pre2 = self._calc_pre2_loss(
  1096. encoder_out, encoder_out_lens, text, text_lengths
  1097. )
  1098. # 3. CTC-Att loss definition
  1099. if self.ctc_weight == 0.0:
  1100. loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
  1101. elif self.ctc_weight == 1.0:
  1102. loss = loss_ctc
  1103. else:
  1104. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
  1105. # Collect Attn branch stats
  1106. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  1107. stats["acc"] = acc_att
  1108. stats["cer"] = cer_att
  1109. stats["wer"] = wer_att
  1110. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  1111. stats["loss_pre2"] = loss_pre2.detach().cpu()
  1112. stats["loss"] = torch.clone(loss.detach())
  1113. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  1114. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  1115. return loss, stats, weight
  1116. class ContextualParaformer(Paraformer):
  1117. """
  1118. Paraformer model with contextual hotword
  1119. """
  1120. def __init__(
  1121. self,
  1122. vocab_size: int,
  1123. token_list: Union[Tuple[str, ...], List[str]],
  1124. frontend: Optional[AbsFrontend],
  1125. specaug: Optional[AbsSpecAug],
  1126. normalize: Optional[AbsNormalize],
  1127. preencoder: Optional[AbsPreEncoder],
  1128. encoder: AbsEncoder,
  1129. postencoder: Optional[AbsPostEncoder],
  1130. decoder: AbsDecoder,
  1131. ctc: CTC,
  1132. ctc_weight: float = 0.5,
  1133. interctc_weight: float = 0.0,
  1134. ignore_id: int = -1,
  1135. blank_id: int = 0,
  1136. sos: int = 1,
  1137. eos: int = 2,
  1138. lsm_weight: float = 0.0,
  1139. length_normalized_loss: bool = False,
  1140. report_cer: bool = True,
  1141. report_wer: bool = True,
  1142. sym_space: str = "<space>",
  1143. sym_blank: str = "<blank>",
  1144. extract_feats_in_collect_stats: bool = True,
  1145. predictor=None,
  1146. predictor_weight: float = 0.0,
  1147. predictor_bias: int = 0,
  1148. sampling_ratio: float = 0.2,
  1149. min_hw_length: int = 2,
  1150. max_hw_length: int = 4,
  1151. sample_rate: float = 0.6,
  1152. batch_rate: float = 0.5,
  1153. double_rate: float = -1.0,
  1154. target_buffer_length: int = -1,
  1155. inner_dim: int = 256,
  1156. bias_encoder_type: str = 'lstm',
  1157. label_bracket: bool = False,
  1158. use_decoder_embedding: bool = False,
  1159. ):
  1160. assert check_argument_types()
  1161. assert 0.0 <= ctc_weight <= 1.0, ctc_weight
  1162. assert 0.0 <= interctc_weight < 1.0, interctc_weight
  1163. super().__init__(
  1164. vocab_size=vocab_size,
  1165. token_list=token_list,
  1166. frontend=frontend,
  1167. specaug=specaug,
  1168. normalize=normalize,
  1169. preencoder=preencoder,
  1170. encoder=encoder,
  1171. postencoder=postencoder,
  1172. decoder=decoder,
  1173. ctc=ctc,
  1174. ctc_weight=ctc_weight,
  1175. interctc_weight=interctc_weight,
  1176. ignore_id=ignore_id,
  1177. blank_id=blank_id,
  1178. sos=sos,
  1179. eos=eos,
  1180. lsm_weight=lsm_weight,
  1181. length_normalized_loss=length_normalized_loss,
  1182. report_cer=report_cer,
  1183. report_wer=report_wer,
  1184. sym_space=sym_space,
  1185. sym_blank=sym_blank,
  1186. extract_feats_in_collect_stats=extract_feats_in_collect_stats,
  1187. predictor=predictor,
  1188. predictor_weight=predictor_weight,
  1189. predictor_bias=predictor_bias,
  1190. sampling_ratio=sampling_ratio,
  1191. )
  1192. if bias_encoder_type == 'lstm':
  1193. logging.warning("enable bias encoder sampling and contextual training")
  1194. self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)
  1195. self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
  1196. else:
  1197. logging.error("Unsupport bias encoder type")
  1198. self.min_hw_length = min_hw_length
  1199. self.max_hw_length = max_hw_length
  1200. self.sample_rate = sample_rate
  1201. self.batch_rate = batch_rate
  1202. self.target_buffer_length = target_buffer_length
  1203. self.double_rate = double_rate
  1204. if self.target_buffer_length > 0:
  1205. self.hotword_buffer = None
  1206. self.length_record = []
  1207. self.current_buffer_length = 0
  1208. self.use_decoder_embedding = use_decoder_embedding
  1209. def forward(
  1210. self,
  1211. speech: torch.Tensor,
  1212. speech_lengths: torch.Tensor,
  1213. text: torch.Tensor,
  1214. text_lengths: torch.Tensor,
  1215. ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
  1216. """Frontend + Encoder + Decoder + Calc loss
  1217. Args:
  1218. speech: (Batch, Length, ...)
  1219. speech_lengths: (Batch, )
  1220. text: (Batch, Length)
  1221. text_lengths: (Batch,)
  1222. """
  1223. assert text_lengths.dim() == 1, text_lengths.shape
  1224. # Check that batch_size is unified
  1225. assert (
  1226. speech.shape[0]
  1227. == speech_lengths.shape[0]
  1228. == text.shape[0]
  1229. == text_lengths.shape[0]
  1230. ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
  1231. batch_size = speech.shape[0]
  1232. self.step_cur += 1
  1233. # for data-parallel
  1234. text = text[:, : text_lengths.max()]
  1235. speech = speech[:, :speech_lengths.max()]
  1236. # 1. Encoder
  1237. encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
  1238. intermediate_outs = None
  1239. if isinstance(encoder_out, tuple):
  1240. intermediate_outs = encoder_out[1]
  1241. encoder_out = encoder_out[0]
  1242. loss_att, acc_att, cer_att, wer_att = None, None, None, None
  1243. loss_ctc, cer_ctc = None, None
  1244. loss_pre = None
  1245. stats = dict()
  1246. # 1. CTC branch
  1247. if self.ctc_weight != 0.0:
  1248. loss_ctc, cer_ctc = self._calc_ctc_loss(
  1249. encoder_out, encoder_out_lens, text, text_lengths
  1250. )
  1251. # Collect CTC branch stats
  1252. stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
  1253. stats["cer_ctc"] = cer_ctc
  1254. # Intermediate CTC (optional)
  1255. loss_interctc = 0.0
  1256. if self.interctc_weight != 0.0 and intermediate_outs is not None:
  1257. for layer_idx, intermediate_out in intermediate_outs:
  1258. # we assume intermediate_out has the same length & padding
  1259. # as those of encoder_out
  1260. loss_ic, cer_ic = self._calc_ctc_loss(
  1261. intermediate_out, encoder_out_lens, text, text_lengths
  1262. )
  1263. loss_interctc = loss_interctc + loss_ic
  1264. # Collect Intermedaite CTC stats
  1265. stats["loss_interctc_layer{}".format(layer_idx)] = (
  1266. loss_ic.detach() if loss_ic is not None else None
  1267. )
  1268. stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
  1269. loss_interctc = loss_interctc / len(intermediate_outs)
  1270. # calculate whole encoder loss
  1271. loss_ctc = (
  1272. 1 - self.interctc_weight
  1273. ) * loss_ctc + self.interctc_weight * loss_interctc
  1274. # 2b. Attention decoder branch
  1275. if self.ctc_weight != 1.0:
  1276. loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
  1277. encoder_out, encoder_out_lens, text, text_lengths
  1278. )
  1279. # 3. CTC-Att loss definition
  1280. if self.ctc_weight == 0.0:
  1281. loss = loss_att + loss_pre * self.predictor_weight
  1282. elif self.ctc_weight == 1.0:
  1283. loss = loss_ctc
  1284. else:
  1285. loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
  1286. # Collect Attn branch stats
  1287. stats["loss_att"] = loss_att.detach() if loss_att is not None else None
  1288. stats["acc"] = acc_att
  1289. stats["cer"] = cer_att
  1290. stats["wer"] = wer_att
  1291. stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
  1292. stats["loss"] = torch.clone(loss.detach())
  1293. # force_gatherable: to-device and to-tensor if scalar for DataParallel
  1294. loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
  1295. return loss, stats, weight
def _sample_hot_word(self, ys_pad, ys_pad_lens):
    """Randomly cut hotword spans out of the targets and encode them.

    For every utterance with at least two tokens, either two spans are cut
    (only for long utterances, with probability ``double_rate``) or, with
    probability ``sample_rate``, one span of at most ``max_hw_length`` tokens.
    All spans are padded, embedded (``self.decoder.embed`` when
    ``use_decoder_embedding`` else ``self.bias_embed``), run through
    ``self.bias_encoder``, and the encoder output at each span's last token
    becomes that hotword's embedding. A single token-0 entry is always
    included first (presumably a "no hotword" slot -- confirm).

    When ``target_buffer_length > 0``, this step's embeddings are also
    appended to ``self.hotword_buffer`` (older entries detached). From the
    second call on, the whole buffer is returned instead of only this step's
    hotwords; once it would reach the target size it is truncated to a
    random length.

    Note: the sequence of ``random`` calls determines the sampled spans.

    Args:
        ys_pad: padded target token ids; rows are indexed as ``ys_pad[i]``.
        ys_pad_lens: per-utterance target lengths.

    Returns:
        Hotword embeddings repeated once per batch item, i.e.
        (batch, num_hotwords, hidden), on ``ys_pad``'s device.
    """
    hw_list = [torch.Tensor([0]).long().to(ys_pad.device)]
    hw_lengths = [0]  # this length is actually for indice, so -1
    for i, length in enumerate(ys_pad_lens):
        if length < 2:
            continue
        if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate:
            # sample double hotword
            _max_hw_length = min(self.max_hw_length, length // 2)
            # first hotword
            start1 = random.randint(0, length // 3)
            end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1)
            hw_tokens1 = ys_pad[i][start1:end1 + 1]
            hw_lengths.append(len(hw_tokens1) - 1)
            hw_list.append(hw_tokens1)
            # second hotword
            start2 = random.randint(end1 + 1, length - self.min_hw_length)
            end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1),
                                  min(length - 1, start2 + self.max_hw_length - 1))
            hw_tokens2 = ys_pad[i][start2:end2 + 1]
            hw_lengths.append(len(hw_tokens2) - 1)
            hw_list.append(hw_tokens2)
            continue
        if random.random() < self.sample_rate:
            if length == 2:
                # Two-token utterance: use it whole.
                hw_tokens = ys_pad[i][:2]
                hw_lengths.append(1)
                hw_list.append(hw_tokens)
            else:
                # NOTE(review): raises ValueError if length < min_hw_length.
                start = random.randint(0, length - self.min_hw_length)
                end = random.randint(min(length - 1, start + self.min_hw_length - 1),
                                     min(length - 1, start + self.max_hw_length - 1)) + 1
                # print(start, end)
                hw_tokens = ys_pad[i][start:end]
                hw_lengths.append(len(hw_tokens) - 1)
                hw_list.append(hw_tokens)
    # padding
    hw_list_pad = pad_list(hw_list, 0)
    if self.use_decoder_embedding:
        hw_embed = self.decoder.embed(hw_list_pad)
    else:
        hw_embed = self.bias_embed(hw_list_pad)
    hw_embed, (_, _) = self.bias_encoder(hw_embed)
    # Pair each hotword row with the index of its last token.
    _ind = np.arange(0, len(hw_list)).tolist()
    # update self.hotword_buffer, throw a part if oversize
    selected = hw_embed[_ind, hw_lengths]
    if self.target_buffer_length > 0:
        _b = selected.shape[0]
        if self.hotword_buffer is None:
            # First call: start the buffer; this step's hotwords are returned.
            self.hotword_buffer = selected
            self.length_record.append(selected.shape[0])
            self.current_buffer_length = _b
        elif self.current_buffer_length + _b < self.target_buffer_length:
            # Buffer still below target: append (old part detached) and use it.
            self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
            self.current_buffer_length += _b
            selected = self.hotword_buffer
        else:
            # Buffer full: append, then keep only a random-sized tail.
            self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
            random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
            self.hotword_buffer = self.hotword_buffer[-1 * random_throw:]
            selected = self.hotword_buffer
            self.current_buffer_length = selected.shape[0]
    return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
  1359. def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
  1360. tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
  1361. ys_pad = ys_pad * tgt_mask[:, :, 0]
  1362. if self.share_embedding:
  1363. ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
  1364. else:
  1365. ys_pad_embed = self.decoder.embed(ys_pad)
  1366. with torch.no_grad():
  1367. decoder_outs = self.decoder(
  1368. encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
  1369. )
  1370. decoder_out, _ = decoder_outs[0], decoder_outs[1]
  1371. pred_tokens = decoder_out.argmax(-1)
  1372. nonpad_positions = ys_pad.ne(self.ignore_id)
  1373. seq_lens = (nonpad_positions).sum(1)
  1374. same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
  1375. input_mask = torch.ones_like(nonpad_positions)
  1376. bsz, seq_len = ys_pad.size()
  1377. for li in range(bsz):
  1378. target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
  1379. if target_num > 0:
  1380. input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
  1381. input_mask = input_mask.eq(1)
  1382. input_mask = input_mask.masked_fill(~nonpad_positions, False)
  1383. input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
  1384. sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
  1385. input_mask_expand_dim, 0)
  1386. return sematic_embeds * tgt_mask, decoder_out * tgt_mask
  def _calc_att_loss(
          self,
          encoder_out: torch.Tensor,
          encoder_out_lens: torch.Tensor,
          ys_pad: torch.Tensor,
          ys_pad_lens: torch.Tensor,
  ):
      """Run the predictor and the (optionally sampled) contextual decoder; compute losses.

      Returns:
          ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``:
          - ``loss_att``: decoder criterion loss on the final pass.
          - ``acc_att``: token accuracy of the first decoder pass. This is the sampler's
            no-grad pass when sampling is enabled, otherwise the final pass.
          - ``cer_att``, ``wer_att``: CER/WER of that first pass. Both are ``None`` while
            training or when no error calculator is configured.
          - ``loss_pre``: predictor length loss.
      """
      enc_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
          encoder_out.device)
      if self.predictor_bias == 1:
          # Keep only the second output of add_sos_eos (presumably the eos-appended targets).
          _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
      ys_pad_lens = ys_pad_lens + self.predictor_bias
      acoustic_embeds, predicted_len, _, _ = self.predictor(encoder_out, ys_pad, enc_mask,
                                                            ignore_id=self.ignore_id)

      # Hotword context consumed by the contextual decoder.
      contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)

      # 0. Sampler: optionally swap some acoustic embeddings for reference embeddings.
      use_sampler = self.sampling_ratio > 0.0
      if self.step_cur < 2:
          if use_sampler:
              logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
          else:
              logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
      if use_sampler:
          decoder_in, first_pass_out = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                    acoustic_embeds, contextual_info)
      else:
          decoder_in, first_pass_out = acoustic_embeds, None

      # 1. Forward decoder.
      decoder_outs = self.decoder(
          encoder_out, encoder_out_lens, decoder_in, ys_pad_lens, contextual_info=contextual_info
      )
      decoder_out, _ = decoder_outs[0], decoder_outs[1]
      if first_pass_out is None:
          first_pass_out = decoder_out

      # 2. Losses and metrics.
      loss_att = self.criterion_att(decoder_out, ys_pad)
      acc_att = th_accuracy(first_pass_out.view(-1, self.vocab_size), ys_pad, ignore_label=self.ignore_id)
      loss_pre = self.criterion_pre(ys_pad_lens.type_as(predicted_len), predicted_len)

      cer_att = wer_att = None
      if not self.training and self.error_calculator is not None:
          ys_hat = first_pass_out.argmax(dim=-1)
          cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
      return loss_att, acc_att, cer_att, wer_att, loss_pre
  def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
      """Run the contextual decoder biased by ``hw_list`` and return token log-probabilities.

      Args:
          encoder_out: encoder output forwarded to ``self.decoder``.
          encoder_out_lens: valid lengths of ``encoder_out``.
          sematic_embeds: decoder input embeddings.
          ys_pad_lens: valid lengths of the decoder input; returned unchanged.
          hw_list: optional sequence of hotwords, each a sequence of token ids. When ``None``,
              a single "hotword" consisting only of ``self.sos`` stands in for an empty list.

      Returns:
          ``(log_probs, ys_pad_lens)``: log-softmax over the last dimension of the first
          decoder output, and ``ys_pad_lens`` unchanged.
      """
      device = encoder_out.device
      if hw_list is None:
          # Default: one <sos>-only entry, i.e. effectively no hotwords. No packing needed.
          lengths = None
          padded_ids = pad_list([torch.Tensor([self.sos]).long().to(device)], 0)
      else:
          lengths = [len(hotword) for hotword in hw_list]
          padded_ids = pad_list([torch.Tensor(hotword).long() for hotword in hw_list], 0).to(device)

      embed_fn = self.decoder.embed if self.use_decoder_embedding else self.bias_embed
      hw_embed = embed_fn(padded_ids)
      if lengths is not None:
          # Pack so the encoder's final hidden state comes from each hotword's own last token.
          hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, lengths, batch_first=True,
                                                             enforce_sorted=False)
      _, (h_n, _) = self.bias_encoder(hw_embed)
      # One hotword representation per entry, repeated for every utterance in the batch.
      contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)

      decoder_out = self.decoder(
          encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
      )[0]
      return torch.log_softmax(decoder_out, dim=-1), ys_pad_lens
  def gen_clas_tf2torch_map_dict(self):
      """Return the TF -> PyTorch parameter map for the CLAS bias encoder and bias embedding.

      Each key is a PyTorch state-dict name. Each value names the TF variable to read
      (``name``) and the post-processing ``clas_convert_tf2torch`` applies to it:
      - ``squeeze``: axis to squeeze.
      - ``transpose``: axis permutation.
      - ``slice``: row range.
      - ``scale``: multiplier.
      - ``unit_k`` / ``unit_b``: per-gate width used to reorder LSTM gate blocks in
        kernels / biases.
      """
      torch_lstm = "bias_encoder"
      tf_cell = "seq2seq/clas_charrnn/rnn/lstm_cell"

      def kernel_entry(rows):
          # TF keeps one fused kernel. Rows [0, 512) map to weight_ih and [512, 1024) to
          # weight_hh; each slice is then transposed: (1024, 2048) -> (2048, 512).
          return {
              "name": tf_cell + "/kernel",
              "squeeze": None,
              "transpose": (1, 0),
              "slice": rows,
              "unit_k": 512,
          }

      def bias_entry():
          # TF keeps a single bias of shape (2048,). PyTorch sums bias_ih and bias_hh, so each
          # receives half of it.
          return {
              "name": tf_cell + "/bias",
              "squeeze": None,
              "transpose": None,
              "scale": 0.5,
              "unit_b": 512,
          }

      return {
          f"{torch_lstm}.weight_ih_l0": kernel_entry((0, 512)),
          f"{torch_lstm}.weight_hh_l0": kernel_entry((512, 1024)),
          f"{torch_lstm}.bias_ih_l0": bias_entry(),
          f"{torch_lstm}.bias_hh_l0": bias_entry(),
          # Character embedding table, copied as-is: (4235, 256) in both frameworks.
          "bias_embed.weight": {
              "name": "seq2seq/contextual_encoder/w_char_embs",
              "squeeze": None,
              "transpose": None,
          },
      }
  def clas_convert_tf2torch(self,
                            var_dict_tf,
                            var_dict_torch):
      """Convert the CLAS bias-encoder and bias-embedding weights from a TF checkpoint.

      Args:
          var_dict_tf: mapping of TF variable name -> numpy array.
          var_dict_torch: mapping of PyTorch parameter name -> tensor. Only its names (to pick
              targets) and sizes (to validate results) are read.

      Returns:
          dict mapping each converted PyTorch name to a float32 CPU tensor. ``bias_encoder.*``
          names absent from the map are skipped; names with other prefixes are ignored.

      Raises:
          AssertionError: if a converted tensor's size differs from the PyTorch parameter.
          KeyError: if a referenced TF variable, or the map entry for a ``bias_embed.*`` name,
              is missing.
      """
      map_dict = self.gen_clas_tf2torch_map_dict()
      var_dict_torch_update = dict()

      def reorder_gates(array, width, axis):
          # Split `axis` into four blocks of `width` and emit them as blocks 0, 2, 1, 3.
          # Assuming TF's fused LSTM layout is [i, j(=g), f, o] (TODO confirm against the
          # exporting graph), swapping the middle two yields PyTorch's [i, f, g, o].
          index = [slice(None)] * array.ndim
          blocks = []
          for block in (0, 2, 1, 3):
              index[axis] = slice(block * width, (block + 1) * width)
              blocks.append(array[tuple(index)])
          return np.concatenate(blocks, axis=axis)

      def store(name, tf_name, array):
          # Cast to a float32 CPU tensor, check it against the PyTorch shape, record and log it.
          tensor = torch.from_numpy(array).type(torch.float32).to("cpu")
          expected_size = var_dict_torch[name].size()
          assert expected_size == tensor.size(), "{}, {}, {} != {}".format(name, tf_name, expected_size,
                                                                           tensor.size())
          var_dict_torch_update[name] = tensor
          logging.info(
              "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, tensor.size(), tf_name,
                                                                          var_dict_tf[tf_name].shape))

      for name in sorted(var_dict_torch):
          prefix = name.split('.')[0]
          if prefix == "bias_encoder":
              if name not in map_dict:
                  continue
              spec = map_dict[name]
              tf_name = spec["name"]
              data = var_dict_tf[tf_name]
              if spec.get("unit_k") is not None:
                  data = reorder_gates(data, spec["unit_k"], axis=1)
              if spec.get("unit_b") is not None:
                  data = reorder_gates(data, spec["unit_b"], axis=0)
              if spec["squeeze"] is not None:
                  data = np.squeeze(data, axis=spec["squeeze"])
              rows = spec.get("slice")
              if rows is not None:
                  data = data[rows[0]:rows[1]]
              if spec.get("scale") is not None:
                  data = data * spec["scale"]
              if spec["transpose"] is not None:
                  data = np.transpose(data, spec["transpose"])
              store(name, tf_name, data)
          elif prefix == "bias_embed":
              spec = map_dict[name]
              tf_name = spec["name"]
              data = var_dict_tf[tf_name]
              if spec["squeeze"] is not None:
                  data = np.squeeze(data, axis=spec["squeeze"])
              if spec["transpose"] is not None:
                  data = np.transpose(data, spec["transpose"])
              store(name, tf_name, data)
      return var_dict_torch_update