| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248 |
- import logging
- from contextlib import contextmanager
- from distutils.version import LooseVersion
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Tuple
- from typing import Union
- import torch
- import random
- import numpy as np
- from funasr.layers.abs_normalize import AbsNormalize
- from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
- )
- from funasr.models.ctc import CTC
- from funasr.models.decoder.abs_decoder import AbsDecoder
- from funasr.models.e2e_asr_common import ErrorCalculator
- from funasr.models.encoder.abs_encoder import AbsEncoder
- from funasr.models.frontend.abs_frontend import AbsFrontend
- from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
- from funasr.models.predictor.cif import mae_loss
- from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
- from funasr.models.specaug.abs_specaug import AbsSpecAug
- from funasr.modules.add_sos_eos import add_sos_eos
- from funasr.modules.nets_utils import make_pad_mask, pad_list
- from funasr.modules.nets_utils import th_accuracy
- from funasr.torch_utils.device_funcs import force_gatherable
- from funasr.models.base_model import FunASRModel
- from funasr.models.predictor.cif import CifPredictorV3
# Mixed-precision context manager.
# torch.cuda.amp only exists from torch 1.6.0 onwards, so probe for it with a
# plain import instead of comparing version strings through
# distutils.version.LooseVersion (distutils was removed in Python 3.12).
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0: provide a no-op stand-in with the same
    # call signature so `with autocast(False):` keeps working.
    @contextmanager
    def autocast(enabled=True):
        yield
class Paraformer(FunASRModel):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
        use_1st_decoder_loss: bool = False,
    ):
        """Store the sub-modules and the loss configuration.

        Args:
            vocab_size: size of the output vocabulary; also used as the
                fallback for ``sos``/``eos`` (``vocab_size - 1``) when those
                are passed as ``None``.
            token_list: vocabulary symbols; a private list copy is kept.
            frontend, specaug, normalize, preencoder, postencoder: optional
                stages applied in ``encode``; ``None`` skips the stage.
            encoder, decoder, ctc: the main sub-modules.
            ctc_weight: weight of the CTC loss in ``forward``.  ``0.0`` drops
                the CTC module (``self.ctc = None``), ``1.0`` drops the
                decoder (``self.decoder = None``).
            interctc_weight: weight of the intermediate-CTC loss in
                ``forward``.
            ignore_id: padding label ignored by the losses.
            blank_id, sos, eos: special token ids.
            lsm_weight: smoothing factor handed to ``LabelSmoothingLoss``.
            length_normalized_loss: forwarded as ``normalize_length`` to both
                ``LabelSmoothingLoss`` and ``mae_loss``.
            report_cer, report_wer: when either is true an ``ErrorCalculator``
                is created.
            sym_space, sym_blank: symbols handed to ``ErrorCalculator``.
            extract_feats_in_collect_stats: whether ``collect_feats`` runs the
                real feature extraction.
            predictor: module called in ``calc_predictor`` and
                ``_calc_att_loss``.
            predictor_weight: weight of the predictor loss in ``forward``.
            predictor_bias: when ``1`` the targets get an ``eos`` appended
                and their lengths are increased by one in ``_calc_att_loss``.
            sampling_ratio: values ``> 0`` enable the sampler in
                ``_calc_att_loss``.
            share_embedding: when true the target embedding is read from
                ``decoder.output_layer.weight`` and ``decoder.embed`` is
                cleared.
            use_1st_decoder_loss: when true the first decoder pass in the
                sampler keeps gradients and contributes an extra loss term.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = vocab_size - 1 if sos is None else sos
        self.eos = vocab_size - 1 if eos is None else eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        # list(...) instead of .copy(): the annotation allows tuples, which
        # have no .copy() method.
        self.token_list = list(token_list)

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.error_calculator = None

        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0

        self.share_embedding = share_embedding
        # The decoder is None for a CTC-only model (ctc_weight == 1.0); there
        # is no embedding to clear in that case.
        if self.share_embedding and self.decoder is not None:
            self.decoder.embed = None

        self.use_1st_decoder_loss = use_1st_decoder_loss

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: int = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            decoding_ind: int, forwarded to
                ``encoder.overlap_chunk_cls.random_choice`` when the encoder
                has an ``overlap_chunk_cls`` attribute.

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``
            from the combined loss, the per-branch statistics and the batch
            size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]

        # 1. Encoder
        if hasattr(self.encoder, "overlap_chunk_cls"):
            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        else:
            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, pre_loss_att, acc_att, cer_att, wer_att = None, None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
            )

        # Optional loss of the first (sampler) decoder pass
        if self.use_1st_decoder_loss and pre_loss_att is not None:
            loss = loss + (1 - self.ctc_weight) * pre_loss_att

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Return the features (or the raw input as dummy features).

        ``text`` and ``text_lengths`` are accepted but not used.
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            ind: int, only passed to the encoder when it has an
                ``overlap_chunk_cls`` attribute.

        Returns:
            ``(encoder_out, encoder_out_lens)``; when the encoder returns a
            tuple as its first output, ``encoder_out`` is itself the pair
            ``(encoder_out, intermediate_outs)``.
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        has_chunk = hasattr(self.encoder, "overlap_chunk_cls")
        encoder_kwargs = {}
        if self.encoder.interctc_use_conditioning:
            encoder_kwargs["ctc"] = self.ctc
        if has_chunk:
            encoder_kwargs["ind"] = ind
        encoder_out, encoder_out_lens, _ = self.encoder(
            feats, feats_lengths, **encoder_kwargs
        )
        if has_chunk:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the predictor on the encoder output without target labels.

        Returns the four predictor outputs unchanged:
        ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``.
        """
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
        """Run the decoder on ``sematic_embeds`` and return log-probabilities.

        Returns:
            ``(decoder_out, ys_pad_lens)`` where ``decoder_out`` is the
            log-softmax (last dim) of the decoder's first output.
        """
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the frontend (if any) to the length-trimmed input."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)

        Returns:
            Tensor of shape (Batch,): per-utterance sum of the token-level
            cross entropy, with ``ignore_id`` positions excluded.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder

        To avoid OOM, this function separates the input into batches.
        Then call nll for each batch and combine and return results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase
                        GPU memory usage
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            nll = []
            # range() raises on a zero step, so a bad batch_size fails fast
            # instead of spinning forever.
            for start_idx in range(0, total_num, batch_size):
                end_idx = min(start_idx + batch_size, total_num)
                batch_nll = self.nll(
                    encoder_out[start_idx:end_idx, :, :],
                    encoder_out_lens[start_idx:end_idx],
                    ys_pad[start_idx:end_idx, :],
                    ys_pad_lens[start_idx:end_idx],
                )
                nll.append(batch_nll)
            nll = torch.cat(nll)
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the decoder (attention) loss and the predictor loss.

        Returns:
            ``(loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att)``.
            ``cer_att``/``wer_att`` are ``None`` while training or without an
            error calculator; ``pre_loss_att`` is ``None`` unless the sampler
            with gradients ran.
        """
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        if self.predictor_bias == 1:
            # Use the eos-appended targets, one token longer
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # 0. sampler
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
                )
            else:
                sematic_embeds, decoder_out_1st = self.sampler(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
                )
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        # Accuracy and CER/WER are measured on the first-pass output
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att

    def _embed_targets(self, ys_pad, tgt_mask):
        """Embed the target tokens, with padded positions mapped to id 0."""
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            return self.decoder.output_layer.weight[ys_pad_masked]
        return self.decoder.embed(ys_pad_masked)

    def _build_input_mask(self, ys_pad, pred_tokens):
        """Choose which positions keep the acoustic embedding.

        Per utterance, ``sampling_ratio * (#tokens - #correct predictions)``
        of the first ``#tokens`` positions are drawn at random and set to
        False; padded positions are always False.
        """
        with torch.no_grad():
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = nonpad_positions.sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            for li in range(ys_pad.size(0)):
                target_num = (
                    (seq_lens[li] - same_num[li].sum()).float() * self.sampling_ratio
                ).long()
                if target_num > 0:
                    # Move the index to the mask's own device instead of a
                    # hard-coded .cuda(), so CPU and multi-GPU runs work.
                    index = torch.randperm(seq_lens[li])[:target_num].to(input_mask.device)
                    input_mask[li].scatter_(dim=0, index=index, value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
        return input_mask

    @staticmethod
    def _mix_embeds(pre_acoustic_embeds, ys_pad_embed, input_mask):
        """Take acoustic embeds where the mask is True, target embeds elsewhere."""
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        return pre_acoustic_embeds.masked_fill(
            ~input_mask_expand_dim, 0
        ) + ys_pad_embed.masked_fill(input_mask_expand_dim, 0)

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Mix acoustic and target embeddings after a gradient-free decoder pass.

        Returns:
            ``(sematic_embeds, decoder_out)``, both multiplied by the target
            padding mask.
        """
        tgt_mask = (
            ~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]
        ).to(ys_pad.device)
        ys_pad_embed = self._embed_targets(ys_pad, tgt_mask)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out = decoder_outs[0]
            pred_tokens = decoder_out.argmax(-1)
        input_mask = self._build_input_mask(ys_pad, pred_tokens)
        sematic_embeds = self._mix_embeds(pre_acoustic_embeds, ys_pad_embed, input_mask)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Same mixing as ``sampler`` but the first decoder pass keeps gradients.

        Returns:
            ``(sematic_embeds, decoder_out, pre_loss_att)`` where
            ``pre_loss_att`` is ``criterion_att`` of the first-pass output.
        """
        tgt_mask = (
            ~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]
        ).to(ys_pad.device)
        ys_pad_embed = self._embed_targets(ys_pad, tgt_mask)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        decoder_out = decoder_outs[0]
        pre_loss_att = self.criterion_att(decoder_out, ys_pad)
        pred_tokens = decoder_out.argmax(-1)
        input_mask = self._build_input_mask(ys_pad, pred_tokens)
        sematic_embeds = self._mix_embeds(pre_acoustic_embeds, ys_pad_embed, input_mask)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Return ``(loss_ctc, cer_ctc)``; the CER is only computed in eval mode."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
- class ParaformerOnline(Paraformer):
- """
- Author: Speech Lab, Alibaba Group, China
- Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
- https://arxiv.org/abs/2206.08317
- """
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    encoder: AbsEncoder,
    decoder: AbsDecoder,
    ctc: CTC,
    ctc_weight: float = 0.5,
    interctc_weight: float = 0.0,
    ignore_id: int = -1,
    blank_id: int = 0,
    sos: int = 1,
    eos: int = 2,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
    predictor=None,
    predictor_weight: float = 0.0,
    predictor_bias: int = 0,
    sampling_ratio: float = 0.2,
    decoder_attention_chunk_type: str = 'chunk',
    share_embedding: bool = False,
    preencoder: Optional[AbsPreEncoder] = None,
    postencoder: Optional[AbsPostEncoder] = None,
    use_1st_decoder_loss: bool = False,
):
    """Build the base Paraformer and add the streaming-specific state.

    Every argument except ``decoder_attention_chunk_type`` is forwarded
    to the parent constructor, which validates the weights and stores
    the sub-modules, losses and flags.  On top of that this constructor
    sets ``scama_mask`` to ``None``, stores
    ``decoder_attention_chunk_type``, and, when the encoder exposes a
    non-``None`` ``overlap_chunk_cls``, keeps a reference to
    ``build_scama_mask_for_cross_attention_decoder``.
    """
    # The parent constructor performs all shared setup; repeating it
    # here would build the losses, the error calculator and the
    # conditioning layer a second time.
    super().__init__(
        vocab_size=vocab_size,
        token_list=token_list,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        ctc_weight=ctc_weight,
        interctc_weight=interctc_weight,
        ignore_id=ignore_id,
        blank_id=blank_id,
        sos=sos,
        eos=eos,
        lsm_weight=lsm_weight,
        length_normalized_loss=length_normalized_loss,
        report_cer=report_cer,
        report_wer=report_wer,
        sym_space=sym_space,
        sym_blank=sym_blank,
        extract_feats_in_collect_stats=extract_feats_in_collect_stats,
        predictor=predictor,
        predictor_weight=predictor_weight,
        predictor_bias=predictor_bias,
        sampling_ratio=sampling_ratio,
        share_embedding=share_embedding,
        use_1st_decoder_loss=use_1st_decoder_loss,
    )

    self.scama_mask = None
    if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
        # Imported lazily so the streaming utilities are only needed
        # for chunk-based encoders.
        from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
        self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
    # Always stored so the attribute exists regardless of encoder type.
    self.decoder_attention_chunk_type = decoder_attention_chunk_type
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    decoding_ind: int = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss.

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
        decoding_ind: handed to ``overlap_chunk_cls.random_choice`` together
            with ``self.training``; the result is passed to ``encode`` as
            ``ind``. Ignored when the encoder has no chunk helper.

    Returns:
        The ``(loss, stats, weight)`` triple produced by
        ``force_gatherable`` from the combined loss, the statistics
        dictionary and the batch size.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
        speech.shape[0]
        == speech_lengths.shape[0]
        == text.shape[0]
        == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]
    self.step_cur += 1
    # for data-parallel
    text = text[:, : text_lengths.max()]
    speech = speech[:, : speech_lengths.max()]

    # The chunk helper is optional: the encoder may lack the attribute
    # entirely or carry None. Treat both cases as "no chunking".
    chunker = getattr(self.encoder, "overlap_chunk_cls", None)

    # 1. Encoder
    if chunker is not None:
        ind = chunker.random_choice(self.training, decoding_ind)
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
    else:
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        intermediate_outs = encoder_out[1]
        encoder_out = encoder_out[0]

    def ctc_inputs(hidden):
        """Apply ``remove_chunk`` when a chunk helper exists, else pass through."""
        if chunker is None:
            return hidden, encoder_out_lens
        return chunker.remove_chunk(hidden, encoder_out_lens, chunk_outs=None)

    loss_att, acc_att, cer_att, wer_att = None, None, None, None
    loss_ctc, cer_ctc = None, None
    loss_pre = None
    # Must exist even when the attention branch is skipped (ctc_weight == 1.0),
    # because it is read unconditionally further down.
    pre_loss_att = None
    stats = dict()

    # 1. CTC branch
    if self.ctc_weight != 0.0:
        encoder_out_ctc, encoder_out_lens_ctc = ctc_inputs(encoder_out)
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
        )

        # Collect CTC branch stats
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc

    # Intermediate CTC (optional)
    loss_interctc = 0.0
    if self.interctc_weight != 0.0 and intermediate_outs is not None:
        for layer_idx, intermediate_out in intermediate_outs:
            # we assume intermediate_out has the same length & padding
            # as those of encoder_out
            encoder_out_ctc, encoder_out_lens_ctc = ctc_inputs(intermediate_out)
            loss_ic, cer_ic = self._calc_ctc_loss(
                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
            )
            loss_interctc = loss_interctc + loss_ic

            # Collect Intermediate CTC stats
            stats["loss_interctc_layer{}".format(layer_idx)] = (
                loss_ic.detach() if loss_ic is not None else None
            )
            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

        loss_interctc = loss_interctc / len(intermediate_outs)

        # calculate whole encoder loss
        loss_ctc = (
            1 - self.interctc_weight
        ) * loss_ctc + self.interctc_weight * loss_interctc

    # 2b. Attention decoder branch
    if self.ctc_weight != 1.0:
        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

    # 3. CTC-Att loss definition
    if self.ctc_weight == 0.0:
        loss = loss_att + loss_pre * self.predictor_weight
    elif self.ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

    if self.use_1st_decoder_loss and pre_loss_att is not None:
        loss = loss + pre_loss_att

    # Collect Attn branch stats
    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
    stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
    stats["acc"] = acc_att
    stats["cer"] = cer_att
    stats["wer"] = wer_att
    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

    stats["loss"] = torch.clone(loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def encode(
    self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Frontend + Encoder. Note that this method is used by asr_inference.py

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        ind: forwarded unchanged to the encoder as the ``ind`` keyword.

    Returns:
        ``(encoder_out, encoder_out_lens)``. When the encoder output is a
        tuple, the first element is ``(encoder_out, intermediate_outs)``.
    """
    with autocast(False):
        # Feature extraction, then optional augmentation (training only)
        # and optional normalization, e.g. Global-CMVN, Utterance-CMVN.
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        if self.specaug is not None and self.training:
            feats, feats_lengths = self.specaug(feats, feats_lengths)
        if self.normalize is not None:
            feats, feats_lengths = self.normalize(feats, feats_lengths)

    # Pre-encoder, e.g. used for raw input data
    if self.preencoder is not None:
        feats, feats_lengths = self.preencoder(feats, feats_lengths)

    # Forward encoder: feats (Batch, Length, Dim) -> encoder_out (Batch, Length2, Dim2).
    # The CTC module is only handed over when inter-CTC conditioning is on.
    if self.encoder.interctc_use_conditioning:
        encoder_kwargs = {"ctc": self.ctc, "ind": ind}
    else:
        encoder_kwargs = {"ind": ind}
    encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, **encoder_kwargs)

    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        encoder_out, intermediate_outs = encoder_out[0], encoder_out[1]

    # Post-encoder, e.g. NLU
    if self.postencoder is not None:
        encoder_out, encoder_out_lens = self.postencoder(encoder_out, encoder_out_lens)

    assert encoder_out.size(0) == speech.size(0), (encoder_out.size(), speech.size(0))
    assert encoder_out.size(1) <= encoder_out_lens.max(), (
        encoder_out.size(),
        encoder_out_lens.max(),
    )

    if intermediate_outs is None:
        return encoder_out, encoder_out_lens
    return (encoder_out, intermediate_outs), encoder_out_lens
def encode_chunk(
    self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Frontend + chunk-wise encoder (``encoder.forward_chunk``).

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        cache: mapping whose ``"encoder"`` entry is passed to
            ``encoder.forward_chunk``. Despite the ``None`` default, a
            mapping is required because it is indexed unconditionally.

    Returns:
        ``((encoder_out, intermediate_outs), encoder_out_lens)`` when the
        encoder output is a tuple; otherwise ``encoder_out`` together with
        a one-element tensor holding ``encoder_out.size(1)``.
    """
    with autocast(False):
        # Feature extraction, then optional augmentation (training only)
        # and optional normalization, e.g. Global-CMVN, Utterance-CMVN.
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        if self.specaug is not None and self.training:
            feats, feats_lengths = self.specaug(feats, feats_lengths)
        if self.normalize is not None:
            feats, feats_lengths = self.normalize(feats, feats_lengths)

    # Pre-encoder, e.g. used for raw input data
    if self.preencoder is not None:
        feats, feats_lengths = self.preencoder(feats, feats_lengths)

    # Forward encoder: feats (Batch, Length, Dim) -> encoder_out (Batch, Length2, Dim2).
    # The CTC module is only handed over when inter-CTC conditioning is on.
    if self.encoder.interctc_use_conditioning:
        conditioning = {"ctc": self.ctc}
    else:
        conditioning = {}
    encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
        feats, feats_lengths, cache=cache["encoder"], **conditioning
    )

    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        encoder_out, intermediate_outs = encoder_out[0], encoder_out[1]

    # Post-encoder, e.g. NLU
    if self.postencoder is not None:
        encoder_out, encoder_out_lens = self.postencoder(encoder_out, encoder_out_lens)

    if intermediate_outs is not None:
        return (encoder_out, intermediate_outs), encoder_out_lens
    return encoder_out, torch.tensor([encoder_out.size(1)])
def _calc_att_predictor_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the attention-decoder loss, the predictor loss and metrics.

    Args:
        encoder_out: encoder output; dimension 1 is used as the mask length.
        encoder_out_lens: per-utterance encoder output lengths.
        ys_pad: padded target token ids.
        ys_pad_lens: per-utterance target lengths.

    Returns:
        ``(loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att)``.
        ``cer_att`` / ``wer_att`` are ``None`` while training or when no
        error calculator is set. ``pre_loss_att`` stays ``None`` unless the
        sampler runs with ``use_1st_decoder_loss`` enabled.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    if self.predictor_bias == 1:
        _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias

    chunker = self.encoder.overlap_chunk_cls
    has_chunker = chunker is not None

    mask_chunk_predictor = None
    if has_chunker:
        mask_chunk_predictor = chunker.get_mask_chunk_predictor(
            None, device=encoder_out.device, batch_size=encoder_out.size(0)
        )
        shift_mask = chunker.get_mask_shfit_chunk(
            None, device=encoder_out.device, batch_size=encoder_out.size(0)
        )
        encoder_out = encoder_out * shift_mask

    pre_acoustic_embeds, pre_token_length, pre_alphas, _peaks = self.predictor(
        encoder_out,
        ys_pad,
        encoder_out_mask,
        ignore_id=self.ignore_id,
        mask_chunk_predictor=mask_chunk_predictor,
        target_label_length=ys_pad_lens,
    )
    predictor_alignments, _alignment_lens = self.predictor.gen_frame_alignments(
        pre_alphas, encoder_out_lens
    )

    # Either build a cross-attention mask (chunk attention) or strip the
    # chunks from the encoder output; neither happens without a chunk helper.
    scama_mask = None
    if has_chunker:
        if self.decoder_attention_chunk_type == 'chunk':
            chunk_size_cur = chunker.chunk_size_pad_shift_cur
            center_bias = 0
            look_back_factor = chunker.decoder_att_look_back_factor_cur
            shift_att_mask = chunker.get_mask_shift_att_chunk_decoder(
                None, device=encoder_out.device, batch_size=encoder_out.size(0)
            )
            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
                predictor_alignments=predictor_alignments,
                encoder_sequence_length=encoder_out_lens,
                chunk_size=1,
                encoder_chunk_size=chunk_size_cur,
                attention_chunk_center_bias=center_bias,
                attention_chunk_size=chunk_size_cur,
                attention_chunk_type=self.decoder_attention_chunk_type,
                step=None,
                predictor_mask_chunk_hopping=mask_chunk_predictor,
                decoder_att_look_back_factor=look_back_factor,
                mask_shift_att_chunk_decoder=shift_att_mask,
                target_length=ys_pad_lens,
                is_training=self.training,
            )
        else:
            encoder_out, encoder_out_lens = chunker.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )

    # 0. sampler
    decoder_out_1st = None
    pre_loss_att = None
    if not self.sampling_ratio > 0.0:
        if self.step_cur < 2:
            logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
        sematic_embeds = pre_acoustic_embeds
    else:
        if self.step_cur < 2:
            logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
        sampler_args = (
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, scama_mask,
        )
        if self.use_1st_decoder_loss:
            sematic_embeds, decoder_out_1st, pre_loss_att = self.sampler_with_grad(*sampler_args)
        else:
            sematic_embeds, decoder_out_1st = self.sampler(*sampler_args)

    # 1. Forward decoder
    decoder_outs = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
    )
    decoder_out, _ = decoder_outs[0], decoder_outs[1]
    # Accuracy and error rates use the sampler's first-pass output when
    # there is one, otherwise the output of the decoder pass above.
    first_pass_out = decoder_out if decoder_out_1st is None else decoder_out_1st

    # 2. Compute attention loss
    loss_att = self.criterion_att(decoder_out, ys_pad)
    acc_att = th_accuracy(
        first_pass_out.view(-1, self.vocab_size),
        ys_pad,
        ignore_label=self.ignore_id,
    )
    loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

    # Compute cer/wer using attention-decoder
    cer_att, wer_att = None, None
    if not self.training and self.error_calculator is not None:
        ys_hat = first_pass_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
    """Mix predictor embeddings with ground-truth token embeddings.

    A first decoder pass (without gradients) is run on the predictor
    embeddings. For each utterance, ``sampling_ratio`` times the number of
    non-padded tokens the first pass got wrong is computed, and that many
    randomly chosen positions have their predictor embedding replaced by the
    embedding of the target token.

    Args:
        encoder_out: encoder output passed to the decoder.
        encoder_out_lens: encoder output lengths passed to the decoder.
        ys_pad: padded target token ids; ``self.ignore_id`` marks padding.
        ys_pad_lens: per-utterance target lengths.
        pre_acoustic_embeds: predictor embeddings, one per target position.
        chunk_mask: optional mask forwarded to the decoder.

    Returns:
        ``(sematic_embeds, decoder_out)``, both multiplied by the target
        padding mask; ``decoder_out`` is the first-pass decoder output.
    """
    tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
    # Padding ids are zeroed so they are valid embedding indices.
    ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
    if self.share_embedding:
        ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
    else:
        ys_pad_embed = self.decoder.embed(ys_pad_masked)
    with torch.no_grad():
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
        )
        decoder_out = decoder_outs[0]
        pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = nonpad_positions.sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        # True keeps the predictor embedding, False swaps in the target embedding.
        input_mask = torch.ones_like(nonpad_positions)
        for li in range(ys_pad.size(0)):
            target_num = ((seq_lens[li] - same_num[li]).float() * self.sampling_ratio).long()
            if target_num > 0:
                # Move the indices to the mask's own device instead of
                # assuming CUDA, so CPU and multi-GPU runs work too.
                swap_index = torch.randperm(int(seq_lens[li]))[:target_num].to(input_mask.device)
                input_mask[li].scatter_(dim=0, index=swap_index, value=0)
        input_mask = input_mask.eq(1)
        input_mask = input_mask.masked_fill(~nonpad_positions, False)
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

    sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
        input_mask_expand_dim, 0)
    return sematic_embeds * tgt_mask, decoder_out * tgt_mask
def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
    """Variant of ``sampler`` whose first decoder pass keeps gradients.

    The first decoder pass runs with gradients enabled and its attention
    loss against ``ys_pad`` is returned as well. For each utterance,
    ``sampling_ratio`` times the number of non-padded tokens the first pass
    got wrong is computed, and that many randomly chosen positions have
    their predictor embedding replaced by the embedding of the target token.

    Args:
        encoder_out: encoder output passed to the decoder.
        encoder_out_lens: encoder output lengths passed to the decoder.
        ys_pad: padded target token ids; ``self.ignore_id`` marks padding.
        ys_pad_lens: per-utterance target lengths.
        pre_acoustic_embeds: predictor embeddings, one per target position.
        chunk_mask: optional mask forwarded to the decoder.

    Returns:
        ``(sematic_embeds, decoder_out, pre_loss_att)``; the first two are
        multiplied by the target padding mask, and ``pre_loss_att`` is
        ``criterion_att`` applied to the first-pass decoder output.
    """
    tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
    # Padding ids are zeroed so they are valid embedding indices.
    ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
    if self.share_embedding:
        ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
    else:
        ys_pad_embed = self.decoder.embed(ys_pad_masked)
    decoder_outs = self.decoder(
        encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
    )
    decoder_out = decoder_outs[0]
    pre_loss_att = self.criterion_att(decoder_out, ys_pad)
    pred_tokens = decoder_out.argmax(-1)
    nonpad_positions = ys_pad.ne(self.ignore_id)
    seq_lens = nonpad_positions.sum(1)
    same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
    # True keeps the predictor embedding, False swaps in the target embedding.
    input_mask = torch.ones_like(nonpad_positions)
    for li in range(ys_pad.size(0)):
        target_num = ((seq_lens[li] - same_num[li]).float() * self.sampling_ratio).long()
        if target_num > 0:
            # Move the indices to the mask's own device instead of
            # assuming CUDA, so CPU and multi-GPU runs work too.
            swap_index = torch.randperm(int(seq_lens[li]))[:target_num].to(input_mask.device)
            input_mask[li].scatter_(dim=0, index=swap_index, value=0)
    input_mask = input_mask.eq(1)
    input_mask = input_mask.masked_fill(~nonpad_positions, False)
    input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

    sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
        input_mask_expand_dim, 0)
    return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run the predictor without targets and cache the cross-attention mask.

    The mask built here (or ``None``) is stored on ``self.scama_mask``,
    which ``cal_decoder_with_predictor`` later passes to the decoder.

    Args:
        encoder_out: encoder output; dimension 1 is used as the mask length.
        encoder_out_lens: per-utterance encoder output lengths.

    Returns:
        ``(pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index)``
        exactly as returned by the predictor.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)

    chunker = self.encoder.overlap_chunk_cls
    has_chunker = chunker is not None

    mask_chunk_predictor = None
    if has_chunker:
        mask_chunk_predictor = chunker.get_mask_chunk_predictor(
            None, device=encoder_out.device, batch_size=encoder_out.size(0)
        )
        shift_mask = chunker.get_mask_shfit_chunk(
            None, device=encoder_out.device, batch_size=encoder_out.size(0)
        )
        encoder_out = encoder_out * shift_mask

    pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(
        encoder_out,
        None,
        encoder_out_mask,
        ignore_id=self.ignore_id,
        mask_chunk_predictor=mask_chunk_predictor,
        target_label_length=None,
    )

    # One extra frame is counted when the predictor uses a tail threshold.
    if self.predictor.tail_threshold > 0.0:
        alignment_lens = encoder_out_lens + 1
    else:
        alignment_lens = encoder_out_lens
    predictor_alignments, _alignment_lens = self.predictor.gen_frame_alignments(
        pre_alphas, alignment_lens
    )

    scama_mask = None
    if has_chunker and self.decoder_attention_chunk_type == 'chunk':
        chunk_size_cur = chunker.chunk_size_pad_shift_cur
        center_bias = 0
        look_back_factor = chunker.decoder_att_look_back_factor_cur
        shift_att_mask = chunker.get_mask_shift_att_chunk_decoder(
            None, device=encoder_out.device, batch_size=encoder_out.size(0)
        )
        scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
            predictor_alignments=predictor_alignments,
            encoder_sequence_length=encoder_out_lens,
            chunk_size=1,
            encoder_chunk_size=chunk_size_cur,
            attention_chunk_center_bias=center_bias,
            attention_chunk_size=chunk_size_cur,
            attention_chunk_type=self.decoder_attention_chunk_type,
            step=None,
            predictor_mask_chunk_hopping=mask_chunk_predictor,
            decoder_att_look_back_factor=look_back_factor,
            mask_shift_att_chunk_decoder=shift_att_mask,
            target_length=None,
            is_training=self.training,
        )
    self.scama_mask = scama_mask

    return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
def calc_predictor_chunk(self, encoder_out, cache=None):
    """Run the predictor on one chunk.

    Args:
        encoder_out: encoder output for the current chunk.
        cache: mapping whose ``"encoder"`` entry is passed to
            ``predictor.forward_chunk``. Despite the ``None`` default, a
            mapping is required because it is indexed unconditionally.

    Returns:
        ``(pre_acoustic_embeds, pre_token_length)`` from the predictor.
    """
    predict_chunk = self.predictor.forward_chunk
    acoustic_embeds, token_length = predict_chunk(encoder_out, cache["encoder"])
    return acoustic_embeds, token_length
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
    """Decode the given embeddings and return log-probabilities.

    The decoder is called with ``self.scama_mask`` as its last argument.

    Returns:
        ``(log_probs, ys_pad_lens)`` where ``log_probs`` is the log-softmax
        over the last dimension of the first decoder output, and
        ``ys_pad_lens`` is returned unchanged.
    """
    logits = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
    )[0]
    return torch.log_softmax(logits, dim=-1), ys_pad_lens
def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
    """Decode one chunk and return log-probabilities.

    Args:
        encoder_out: encoder output for the current chunk.
        sematic_embeds: embeddings fed to the decoder.
        cache: mapping whose ``"decoder"`` entry is passed to
            ``decoder.forward_chunk``. Despite the ``None`` default, a
            mapping is required because it is indexed unconditionally.

    Returns:
        Log-softmax over the last dimension of the decoder output.
    """
    decode_chunk = self.decoder.forward_chunk
    logits = decode_chunk(encoder_out, sematic_embeds, cache["decoder"])
    return torch.log_softmax(logits, dim=-1)
class ParaformerBert(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition

    On top of the base model this class adds a linear projection of the
    decoder's embedding output (``pro_nn``) and a cosine-similarity loss
    against externally supplied embeddings, weighted by
    ``embeds_loss_weight``.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        embeds_id: int = 2,
        embeds_loss_weight: float = 0.0,
        embed_dims: int = 768,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
    ):
        """Build the base model and add the embedding-projection head.

        All arguments except ``embeds_id``, ``embeds_loss_weight`` and
        ``embed_dims`` are forwarded to the parent constructor.

        Args:
            embeds_id: stored on the decoder as ``decoder.embeds_id``.
            embeds_loss_weight: weight of the cosine embedding loss.
            embed_dims: output size of the projection ``pro_nn``.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )
        self.decoder.embeds_id = embeds_id
        decoder_attention_dim = self.decoder.attention_dim
        self.pro_nn = torch.nn.Linear(decoder_attention_dim, embed_dims)
        self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.embeds_loss_weight = embeds_loss_weight
        self.length_normalized_loss = length_normalized_loss

    def _calc_embed_loss(
        self,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        embed: torch.Tensor = None,
        embed_lengths: torch.Tensor = None,
        embeds_outputs: torch.Tensor = None,
    ):
        """Cosine-distance loss between projected decoder embeddings and ``embed``.

        Neither ``embed`` nor ``embeds_outputs`` is modified; masking is done
        out of place so the caller's tensors stay intact.

        Args:
            ys_pad: padded targets; only its device is used.
            ys_pad_lens: per-utterance target lengths, used for the mask.
            embed: reference embeddings — assumed (Batch, Length, embed_dims),
                TODO confirm against the data loader.
            embed_lengths: unused here.
            embeds_outputs: decoder embedding output fed through ``pro_nn``.

        Returns:
            Sum of the masked ``1 - cosine_similarity`` values, divided by
            the number of unmasked tokens when ``length_normalized_loss`` is
            set and by the batch size otherwise.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        projected = self.pro_nn(embeds_outputs) * tgt_mask  # b x l x d
        reference = embed * tgt_mask  # b x l x d
        cos_loss = (1.0 - self.cos(projected, reference)) * tgt_mask.squeeze(2)
        if self.length_normalized_loss:
            token_num_total = torch.sum(tgt_mask)
        else:
            token_num_total = tgt_mask.size()[0]
        return torch.sum(cos_loss) / token_num_total

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the attention loss, predictor loss and decoder embeddings.

        Returns:
            ``(loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs)``.
            ``embeds_outputs`` is the third decoder output when the decoder
            returns more than two values, otherwise ``None``.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info(
                    "enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(
                encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
            )
        else:
            if self.step_cur < 2:
                logging.info(
                    "disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        embeds_outputs = None
        if len(decoder_outs) > 2:
            embeds_outputs = decoder_outs[2]

        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        embed: torch.Tensor = None,
        embed_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            embed: optional reference embeddings for the cosine loss; when
                omitted the cosine loss is skipped.
            embed_lengths: lengths used to trim ``embed``.

        Returns:
            The ``(loss, stats, weight)`` triple produced by
            ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, : speech_lengths.max()]
        if embed is not None:
            embed = embed[:, : embed_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = 0.0
        cos_loss = 0.0
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_ret = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss_att, acc_att, cer_att, wer_att, loss_pre = loss_ret[:5]
            embeds_outputs = loss_ret[5] if len(loss_ret) > 5 else None
            # ``embed`` is optional; without it there is nothing to compare
            # the decoder embeddings against.
            if embeds_outputs is not None and embed is not None:
                cos_loss = self._calc_embed_loss(text, text_lengths, embed, embed_lengths, embeds_outputs)

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        # The float placeholder 0.0 (branch not run) is reported as None.
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre > 0.0 else None
        stats["cos_loss"] = cos_loss.detach().cpu() if cos_loss > 0.0 else None

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
- class BiCifParaformer(Paraformer):
- """
- Paraformer model with an extra cif predictor
- to conduct accurate timestamp prediction
- """
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    encoder: AbsEncoder,
    decoder: AbsDecoder,
    ctc: CTC,
    ctc_weight: float = 0.5,
    interctc_weight: float = 0.0,
    ignore_id: int = -1,
    blank_id: int = 0,
    sos: int = 1,
    eos: int = 2,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
    predictor=None,
    predictor_weight: float = 0.0,
    predictor_bias: int = 0,
    sampling_ratio: float = 0.2,
    preencoder: Optional[AbsPreEncoder] = None,
    postencoder: Optional[AbsPostEncoder] = None,
):
    """Forward every argument to the parent constructor and check the predictor.

    ``ctc_weight`` must lie in [0, 1] and ``interctc_weight`` in [0, 1).
    After the parent is initialised, ``self.predictor`` must be a
    ``CifPredictorV3`` instance; all three conditions are enforced with
    ``assert``.
    """
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    assert 0.0 <= interctc_weight < 1.0, interctc_weight

    super().__init__(
        # vocabulary and special symbols
        vocab_size=vocab_size,
        token_list=token_list,
        ignore_id=ignore_id,
        blank_id=blank_id,
        sos=sos,
        eos=eos,
        sym_space=sym_space,
        sym_blank=sym_blank,
        # feature pipeline
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        extract_feats_in_collect_stats=extract_feats_in_collect_stats,
        # network modules
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        predictor=predictor,
        # loss configuration
        ctc_weight=ctc_weight,
        interctc_weight=interctc_weight,
        lsm_weight=lsm_weight,
        length_normalized_loss=length_normalized_loss,
        predictor_weight=predictor_weight,
        predictor_bias=predictor_bias,
        sampling_ratio=sampling_ratio,
        # reporting
        report_cer=report_cer,
        report_wer=report_wer,
    )
    assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
def _calc_pre2_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Loss on the predictor's fifth output (second token-length estimate).

    Args:
        encoder_out: encoder output; dimension 1 is used as the mask length.
        encoder_out_lens: per-utterance encoder output lengths.
        ys_pad: padded target token ids.
        ys_pad_lens: per-utterance target lengths.

    Returns:
        ``criterion_pre`` applied to the target lengths and the predictor's
        fifth return value.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    if self.predictor_bias == 1:
        _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias

    # Only the last of the predictor's five outputs is needed here.
    _, _, _, _, pre_token_length2 = self.predictor(
        encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
    )
    target_lengths = ys_pad_lens.type_as(pre_token_length2)
    return self.criterion_pre(target_lengths, pre_token_length2)
def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the attention-decoder loss, the predictor loss and metrics.

    Args:
        encoder_out: encoder output; dimension 1 is used as the mask length.
        encoder_out_lens: per-utterance encoder output lengths.
        ys_pad: padded target token ids.
        ys_pad_lens: per-utterance target lengths.

    Returns:
        ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``. ``cer_att`` and
        ``wer_att`` are ``None`` while training or when no error calculator
        is set.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    if self.predictor_bias == 1:
        _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias
    pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(
        encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
    )

    # 0. sampler
    decoder_out_1st = None
    if not self.sampling_ratio > 0.0:
        if self.step_cur < 2:
            logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
        sematic_embeds = pre_acoustic_embeds
    else:
        if self.step_cur < 2:
            logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
        sematic_embeds, decoder_out_1st = self.sampler(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
        )

    # 1. Forward decoder
    decoder_outs = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    )
    decoder_out, _ = decoder_outs[0], decoder_outs[1]
    # Accuracy and error rates use the sampler's first-pass output when
    # there is one, otherwise the output of the decoder pass above.
    first_pass_out = decoder_out if decoder_out_1st is None else decoder_out_1st

    # 2. Compute attention loss
    loss_att = self.criterion_att(decoder_out, ys_pad)
    acc_att = th_accuracy(
        first_pass_out.view(-1, self.vocab_size),
        ys_pad,
        ignore_label=self.ignore_id,
    )
    loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

    # Compute cer/wer using attention-decoder
    cer_att, wer_att = None, None
    if not self.training and self.error_calculator is not None:
        ys_hat = first_pass_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    return loss_att, acc_att, cer_att, wer_att, loss_pre
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run the predictor without targets.

    Args:
        encoder_out: encoder output; dimension 1 is used as the mask length.
        encoder_out_lens: per-utterance encoder output lengths.

    Returns:
        The first four predictor outputs,
        ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``;
        the fifth output is discarded.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    acoustic_embeds, token_length, alphas, peak_index, _token_length2 = self.predictor(
        encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
    )
    return acoustic_embeds, token_length, alphas, peak_index
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
    """Delegate to ``predictor.get_upsample_timestamp`` with a padding mask.

    Args:
        encoder_out: encoder output; dimension 1 is used as the mask length.
        encoder_out_lens: per-utterance encoder output lengths.
        token_num: passed through to the predictor.

    Returns:
        ``(ds_alphas, ds_cif_peak, us_alphas, us_peaks)`` as returned by the
        predictor.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)
    ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
        encoder_out, encoder_out_mask, token_num
    )
    return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable``;
        ``weight`` is derived from the batch size.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]
    self.step_cur += 1

    # for data-parallel: trim padding down to the longest item in this shard
    text = text[:, : text_lengths.max()]
    speech = speech[:, :speech_lengths.max()]

    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        # The encoder may return (final_output, intermediate_outputs).
        intermediate_outs = encoder_out[1]
        encoder_out = encoder_out[0]

    loss_att, acc_att, cer_att, wer_att = None, None, None, None
    loss_ctc, cer_ctc = None, None
    loss_pre = None
    # Initialised here so the CTC-only configuration (ctc_weight == 1.0),
    # which skips the decoder branch, does not hit an unbound local below.
    loss_pre2 = None
    stats = dict()

    # 1. CTC branch
    if self.ctc_weight != 0.0:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        # Collect CTC branch stats
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc

    # Intermediate CTC (optional)
    loss_interctc = 0.0
    if self.interctc_weight != 0.0 and intermediate_outs is not None:
        for layer_idx, intermediate_out in intermediate_outs:
            # we assume intermediate_out has the same length & padding
            # as those of encoder_out
            loss_ic, cer_ic = self._calc_ctc_loss(
                intermediate_out, encoder_out_lens, text, text_lengths
            )
            loss_interctc = loss_interctc + loss_ic

            # Collect Intermediate CTC stats
            stats["loss_interctc_layer{}".format(layer_idx)] = (
                loss_ic.detach() if loss_ic is not None else None
            )
            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

        loss_interctc = loss_interctc / len(intermediate_outs)

        # calculate whole encoder loss
        loss_ctc = (
            1 - self.interctc_weight
        ) * loss_ctc + self.interctc_weight * loss_interctc

    # 2b. Attention decoder branch (also yields both predictor losses)
    if self.ctc_weight != 1.0:
        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

        loss_pre2 = self._calc_pre2_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )

    # 3. CTC-Att loss definition
    if self.ctc_weight == 0.0:
        loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
    elif self.ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (
            1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5

    # Collect Attn branch stats
    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
    stats["acc"] = acc_att
    stats["cer"] = cer_att
    stats["wer"] = wer_att
    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
    stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre2 is not None else None

    stats["loss"] = torch.clone(loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
- class ContextualParaformer(Paraformer):
- """
- Paraformer model with contextual hotword
- """
def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        min_hw_length: int = 2,
        max_hw_length: int = 4,
        sample_rate: float = 0.6,
        batch_rate: float = 0.5,
        double_rate: float = -1.0,
        target_buffer_length: int = -1,
        inner_dim: int = 256,
        bias_encoder_type: str = 'lstm',
        label_bracket: bool = False,
        use_decoder_embedding: bool = False,
        preencoder: Optional[AbsPreEncoder] = None,
        postencoder: Optional[AbsPostEncoder] = None,
):
    """Build the base Paraformer and add the hotword (bias) encoder.

    Everything up to ``sampling_ratio`` plus ``preencoder``/``postencoder``
    is forwarded unchanged to the parent constructor. The remaining
    arguments configure hotword sampling:

    Args:
        min_hw_length / max_hw_length: length bounds used when sampling
            hotwords from the targets.
        sample_rate: probability threshold for sampling one hotword from an
            utterance.
        batch_rate: stored on the instance; not used in this constructor.
        double_rate: probability threshold for sampling two hotwords from a
            sufficiently long utterance.
        target_buffer_length: when positive, hotword-buffer bookkeeping
            attributes are created.
        inner_dim: input and hidden size of the bias LSTM and embedding
            size of ``bias_embed``.
        bias_encoder_type: only ``'lstm'`` creates the bias modules; any
            other value just logs an error.
        label_bracket: accepted but not used in this constructor.
        use_decoder_embedding: stored flag selecting the decoder embedding
            instead of ``bias_embed`` for hotwords.
    """
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    assert 0.0 <= interctc_weight < 1.0, interctc_weight

    base_kwargs = dict(
        vocab_size=vocab_size,
        token_list=token_list,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        ctc_weight=ctc_weight,
        interctc_weight=interctc_weight,
        ignore_id=ignore_id,
        blank_id=blank_id,
        sos=sos,
        eos=eos,
        lsm_weight=lsm_weight,
        length_normalized_loss=length_normalized_loss,
        report_cer=report_cer,
        report_wer=report_wer,
        sym_space=sym_space,
        sym_blank=sym_blank,
        extract_feats_in_collect_stats=extract_feats_in_collect_stats,
        predictor=predictor,
        predictor_weight=predictor_weight,
        predictor_bias=predictor_bias,
        sampling_ratio=sampling_ratio,
    )
    super().__init__(**base_kwargs)

    if bias_encoder_type != 'lstm':
        # No bias modules are created in this case; only an error is logged.
        logging.error("Unsupport bias encoder type")
    else:
        logging.warning("enable bias encoder sampling and contextual training")
        # Single-layer, batch-first LSTM that encodes hotword embeddings.
        self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)
        self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)

    # Hotword sampling configuration.
    self.min_hw_length = min_hw_length
    self.max_hw_length = max_hw_length
    self.sample_rate = sample_rate
    self.batch_rate = batch_rate
    self.double_rate = double_rate
    self.use_decoder_embedding = use_decoder_embedding

    # Optional buffer state, created only when buffering is enabled.
    self.target_buffer_length = target_buffer_length
    if target_buffer_length > 0:
        self.hotword_buffer = None
        self.length_record = []
        self.current_buffer_length = 0
def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable``.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # All four inputs must agree on the batch dimension.
    batch_dims = {
        speech.shape[0],
        speech_lengths.shape[0],
        text.shape[0],
        text_lengths.shape[0],
    }
    assert len(batch_dims) == 1, (
        speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]
    self.step_cur += 1

    # for data-parallel: cut padding down to the longest item in this shard
    text = text[:, : text_lengths.max()]
    speech = speech[:, :speech_lengths.max()]

    # --- Encoder -----------------------------------------------------
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        encoder_out, intermediate_outs = encoder_out[0], encoder_out[1]

    loss_att = acc_att = cer_att = wer_att = None
    loss_ctc = cer_ctc = None
    loss_pre = None
    stats = dict()

    # --- CTC branch --------------------------------------------------
    if self.ctc_weight != 0.0:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        stats["loss_ctc"] = None if loss_ctc is None else loss_ctc.detach()
        stats["cer_ctc"] = cer_ctc

    # --- Intermediate CTC (optional) ---------------------------------
    if self.interctc_weight != 0.0 and intermediate_outs is not None:
        interctc_total = 0.0
        for layer_idx, intermediate_out in intermediate_outs:
            # intermediate_out is assumed to share length and padding
            # with encoder_out
            loss_ic, cer_ic = self._calc_ctc_loss(
                intermediate_out, encoder_out_lens, text, text_lengths
            )
            interctc_total = interctc_total + loss_ic

            loss_key = "loss_interctc_layer{}".format(layer_idx)
            stats[loss_key] = None if loss_ic is None else loss_ic.detach()
            cer_key = "cer_interctc_layer{}".format(layer_idx)
            stats[cer_key] = cer_ic

        interctc_mean = interctc_total / len(intermediate_outs)
        # Blend the final-layer CTC loss with the intermediate average.
        loss_ctc = (1 - self.interctc_weight) * loss_ctc + self.interctc_weight * interctc_mean

    # --- Attention decoder branch ------------------------------------
    if self.ctc_weight != 1.0:
        att_outs = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        loss_att, acc_att, cer_att, wer_att, loss_pre = att_outs

    # --- Combine losses ----------------------------------------------
    if self.ctc_weight == 1.0:
        loss = loss_ctc
    elif self.ctc_weight == 0.0:
        loss = loss_att + loss_pre * self.predictor_weight
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

    # --- Attention branch stats --------------------------------------
    stats["loss_att"] = None if loss_att is None else loss_att.detach()
    stats["acc"] = acc_att
    stats["cer"] = cer_att
    stats["wer"] = wer_att
    stats["loss_pre"] = None if loss_pre is None else loss_pre.detach().cpu()
    stats["loss"] = torch.clone(loss.detach())

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
def _sample_hot_word(self, ys_pad, ys_pad_lens):
    """Sample hotword spans from the targets and encode them.

    For every utterance with at least two tokens, zero, one or two
    contiguous token spans are drawn at random from ``ys_pad`` (governed by
    ``self.double_rate`` and ``self.sample_rate``). A single-token ``[0]``
    entry is always placed first in the list. The spans are padded,
    embedded, run through ``self.bias_encoder``, and the LSTM output at each
    span's last valid step is selected. When ``self.target_buffer_length``
    is positive the selection is merged with a running buffer kept on the
    instance.

    Returns:
        The selected encodings repeated along a new leading axis
        ``ys_pad.shape[0]`` times, on ``ys_pad``'s device.
    """
    device = ys_pad.device
    min_len, max_len = self.min_hw_length, self.max_hw_length

    hw_list = [torch.Tensor([0]).long().to(device)]
    # Stored values are "length - 1": they are used as time indices below.
    hw_lengths = [0]

    for row, length in enumerate(ys_pad_lens):
        if length < 2:
            continue
        tokens = ys_pad[row]

        # Two hotwords: only considered for long enough utterances. The
        # random draw happens only when the length test passes.
        if length > min_len + max_len + 2 and random.random() < self.double_rate:
            capped_max = min(max_len, length // 2)
            # first hotword
            first_start = random.randint(0, length // 3)
            first_end = random.randint(first_start + min_len - 1, first_start + capped_max - 1)
            first_span = tokens[first_start:first_end + 1]
            hw_lengths.append(len(first_span) - 1)
            hw_list.append(first_span)
            # second hotword, strictly after the first one
            second_start = random.randint(first_end + 1, length - min_len)
            second_end = random.randint(min(length - 1, second_start + min_len - 1),
                                        min(length - 1, second_start + max_len - 1))
            second_span = tokens[second_start:second_end + 1]
            hw_lengths.append(len(second_span) - 1)
            hw_list.append(second_span)
            continue

        # Single hotword.
        if not (random.random() < self.sample_rate):
            continue
        if length == 2:
            hw_lengths.append(1)
            hw_list.append(tokens[:2])
            continue
        start = random.randint(0, length - min_len)
        stop = random.randint(min(length - 1, start + min_len - 1),
                              min(length - 1, start + max_len - 1)) + 1
        span = tokens[start:stop]
        hw_lengths.append(len(span) - 1)
        hw_list.append(span)

    # Pad, embed and encode all hotwords in one batch.
    hw_list_pad = pad_list(hw_list, 0)
    embed_layer = self.decoder.embed if self.use_decoder_embedding else self.bias_embed
    hw_embed, (_, _) = self.bias_encoder(embed_layer(hw_list_pad))

    # Pick, for each hotword, the encoder output at its last valid step.
    batch_index = list(range(len(hw_list)))
    selected = hw_embed[batch_index, hw_lengths]

    # Update self.hotword_buffer; drop part of it once it grows too large.
    if self.target_buffer_length > 0:
        n_new = selected.shape[0]
        if self.hotword_buffer is None:
            self.hotword_buffer = selected
            self.length_record.append(n_new)
            self.current_buffer_length = n_new
        else:
            merged = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
            if self.current_buffer_length + n_new < self.target_buffer_length:
                self.current_buffer_length += n_new
            else:
                keep = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
                merged = merged[-1 * keep:]
                self.current_buffer_length = merged.shape[0]
            self.hotword_buffer = merged
            selected = merged

    return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(device)
def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
    """Mix predictor embeddings with ground-truth token embeddings.

    A first decoder pass is run without gradients on ``pre_acoustic_embeds``.
    For every utterance, a number of positions proportional to the count of
    wrongly predicted tokens (scaled by ``self.sampling_ratio``) is chosen at
    random; at those positions the acoustic embedding is replaced by the
    embedding of the ground-truth token.

    Args:
        encoder_out, encoder_out_lens: encoder output and its lengths, passed
            to the decoder unchanged.
        ys_pad: padded target token ids.
        ys_pad_lens: target lengths; used to build the target mask.
        pre_acoustic_embeds: predictor output fed to the first decoder pass.
        contextual_info: hotword encodings forwarded to the decoder.

    Returns:
        ``(sematic_embeds * tgt_mask, decoder_out * tgt_mask)`` where
        ``decoder_out`` is the first-pass decoder output.
    """
    tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
    # Multiplying by the mask sets padded target positions to 0 so they are
    # valid indices for the embedding lookup below.
    ys_pad = ys_pad * tgt_mask[:, :, 0]
    if self.share_embedding:
        ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
    else:
        ys_pad_embed = self.decoder.embed(ys_pad)
    with torch.no_grad():
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
        )
        decoder_out = decoder_outs[0]
        pred_tokens = decoder_out.argmax(-1)
        # NOTE(review): ys_pad was masked above, so padded entries are 0 here
        # and only compare equal to ignore_id if ignore_id is 0 - confirm
        # this is the intended definition of "non-pad".
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        input_mask = torch.ones_like(nonpad_positions)
        bsz = ys_pad.size(0)
        for li in range(bsz):
            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
            if target_num > 0:
                # The permutation is generated as before and then moved to the
                # mask's own device, so this also works on CPU and on a
                # non-default GPU (the original hard-coded ``.cuda()``).
                drop_index = torch.randperm(seq_lens[li])[:target_num].to(input_mask.device)
                input_mask[li].scatter_(dim=0, index=drop_index, value=0)
        input_mask = input_mask.eq(1)
        input_mask = input_mask.masked_fill(~nonpad_positions, False)
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

    # Keep acoustic embeddings where the mask is True and ground-truth token
    # embeddings where it was cleared.
    sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
        input_mask_expand_dim, 0)
    return sematic_embeds * tgt_mask, decoder_out * tgt_mask
def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
):
    """Compute the decoder loss, accuracy, CER/WER and predictor loss.

    Steps: run the predictor with the targets, sample hotword context,
    optionally mix in ground-truth embeddings through ``self.sampler``,
    run the decoder with the contextual information and evaluate the
    criteria.

    Returns:
        ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``; ``cer_att`` and
        ``wer_att`` are ``None`` while training or when no error calculator
        is configured.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)

    if self.predictor_bias == 1:
        # Targets get sos/eos added and the lengths grow by the bias.
        _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias

    predictor_outs = self.predictor(
        encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
    )
    pre_acoustic_embeds, pre_token_length, _, pre_peak_index = predictor_outs

    # sample hot word
    contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)

    # 0. sampler
    sampling_enabled = self.sampling_ratio > 0.0
    if self.step_cur < 2:
        # Report the sampler state during the first step only.
        if sampling_enabled:
            logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
        else:
            logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))

    if sampling_enabled:
        sematic_embeds, decoder_out_1st = self.sampler(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
            pre_acoustic_embeds, contextual_info)
    else:
        sematic_embeds, decoder_out_1st = pre_acoustic_embeds, None

    # 1. Forward decoder
    decoder_outs = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
    )
    decoder_out = decoder_outs[0]
    if decoder_out_1st is None:
        # Without the sampler there is only one decoder pass.
        decoder_out_1st = decoder_out

    # 2. Compute attention loss; accuracy uses the first-pass output
    loss_att = self.criterion_att(decoder_out, ys_pad)
    flat_first_pass = decoder_out_1st.view(-1, self.vocab_size)
    acc_att = th_accuracy(flat_first_pass, ys_pad, ignore_label=self.ignore_id)
    loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

    # Compute cer/wer using attention-decoder (evaluation only)
    cer_att, wer_att = None, None
    if not (self.training or self.error_calculator is None):
        ys_hat = decoder_out_1st.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    return loss_att, acc_att, cer_att, wer_att, loss_pre
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None, clas_scale=1.0):
    """Decode with hotword context and return log-probabilities.

    Args:
        encoder_out, encoder_out_lens: encoder output and lengths.
        sematic_embeds: decoder input embeddings.
        ys_pad_lens: lengths passed to the decoder and returned unchanged.
        hw_list: optional list of hotword token-id sequences. When ``None``
            a single one-token hotword ``[self.sos]`` is used.
        clas_scale: accepted for interface compatibility; not used here.

    Returns:
        ``(decoder_out, ys_pad_lens)`` where ``decoder_out`` is the
        log-softmax over the last axis of the decoder's first output.
    """
    device = encoder_out.device
    no_hotwords = hw_list is None

    if no_hotwords:
        # default hotword list: a single entry standing for "no hotword"
        hw_list_pad = pad_list([torch.Tensor([self.sos]).long().to(device)], 0)
    else:
        hw_lengths = [len(hotword) for hotword in hw_list]
        hw_tensors = [torch.Tensor(hotword).long() for hotword in hw_list]
        hw_list_pad = pad_list(hw_tensors, 0).to(device)

    embed_layer = self.decoder.embed if self.use_decoder_embedding else self.bias_embed
    hw_embed = embed_layer(hw_list_pad)

    if not no_hotwords:
        # Pack so the LSTM's final hidden state reflects each hotword's
        # true length rather than the padded length.
        hw_embed = torch.nn.utils.rnn.pack_padded_sequence(
            hw_embed, hw_lengths, batch_first=True, enforce_sorted=False)

    _, (h_n, _) = self.bias_encoder(hw_embed)
    # Share the same hotword encodings across every utterance in the batch.
    contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)

    decoder_outs = self.decoder(
        encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
    )
    log_probs = torch.log_softmax(decoder_outs[0], dim=-1)
    return log_probs, ys_pad_lens
def gen_clas_tf2torch_map_dict(self):
    """Return the TF-to-torch parameter mapping for the hotword modules.

    Keys are torch parameter names of ``bias_encoder`` (LSTM) and
    ``bias_embed``; each value names the TF variable to load from and the
    transformations (``squeeze``, ``transpose``, ``slice``, ``scale``,
    ``unit_k``/``unit_b``) consumed by ``clas_convert_tf2torch``.
    """
    torch_lstm = "bias_encoder"
    tf_lstm = "seq2seq/clas_charrnn"
    torch_emb = "bias_embed"
    tf_emb = "seq2seq"

    kernel_name = "{}/rnn/lstm_cell/kernel".format(tf_lstm)
    bias_name = "{}/rnn/lstm_cell/bias".format(tf_lstm)

    def kernel_entry(row_range):
        # Both LSTM weight matrices come from one TF kernel; ``slice``
        # selects which rows of it belong to this weight.
        return {
            "name": kernel_name,
            "squeeze": None,
            "transpose": (1, 0),
            "slice": row_range,
            "unit_k": 512,
        }  # (1024, 2048),(2048,512)

    def bias_entry():
        # The single TF bias is loaded into both torch biases at half scale.
        return {
            "name": bias_name,
            "squeeze": None,
            "transpose": None,
            "scale": 0.5,
            "unit_b": 512,
        }  # (2048,),(2048,)

    map_dict_local = {
        # in lstm
        "{}.weight_ih_l0".format(torch_lstm): kernel_entry((0, 512)),
        "{}.weight_hh_l0".format(torch_lstm): kernel_entry((512, 1024)),
        "{}.bias_ih_l0".format(torch_lstm): bias_entry(),
        "{}.bias_hh_l0".format(torch_lstm): bias_entry(),
        # in embed
        "{}.weight".format(torch_emb): {
            "name": "{}/contextual_encoder/w_char_embs".format(tf_emb),
            "squeeze": None,
            "transpose": None,
        },  # (4235,256),(4235,256)
    }
    return map_dict_local
def clas_convert_tf2torch(self,
                          var_dict_tf,
                          var_dict_torch):
    """Convert TF hotword-module variables into torch tensors.

    Iterates over the torch parameter names in sorted order. Names starting
    with ``bias_encoder`` that appear in the mapping, and all names starting
    with ``bias_embed``, are loaded from ``var_dict_tf`` and transformed as
    described by ``self.gen_clas_tf2torch_map_dict()``.

    Args:
        var_dict_tf: mapping from TF variable name to numpy array.
        var_dict_torch: mapping from torch parameter name to tensor; used
            for names and for the shape check.

    Returns:
        Dict of torch parameter name -> converted float32 CPU tensor.

    Raises:
        AssertionError: if a converted tensor's size differs from the torch
            parameter's size.
    """
    map_dict = self.gen_clas_tf2torch_map_dict()
    var_dict_torch_update = dict()

    def _swap_gate_blocks(data, dim, axis):
        # Split the first 4*dim entries along ``axis`` into four blocks and
        # emit them in the order 0, 2, 1, 3.
        def block(k):
            index = [slice(None)] * axis + [slice(k * dim, (k + 1) * dim)]
            return data[tuple(index)].copy()

        return np.concatenate([block(0), block(2), block(1), block(3)], axis=axis)

    def _store(name, name_tf, data_tf):
        # Convert to torch, verify the shape, record and log the result.
        tensor = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
        assert var_dict_torch[name].size() == tensor.size(), "{}, {}, {} != {}".format(
            name, name_tf, var_dict_torch[name].size(), tensor.size())
        var_dict_torch_update[name] = tensor
        logging.info(
            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                name, tensor.size(), name_tf, var_dict_tf[name_tf].shape))

    for name in sorted(var_dict_torch.keys(), reverse=False):
        module_name = name.split('.')[0]

        if module_name == "bias_encoder":
            if name not in map_dict.keys():
                continue
            spec = map_dict[name]
            name_tf = spec["name"]
            data_tf = var_dict_tf[name_tf]
            # Gate reordering happens on the raw TF layout, before any
            # slicing or transposition.
            if spec.get("unit_k") is not None:
                data_tf = _swap_gate_blocks(data_tf, spec["unit_k"], axis=1)
            if spec.get("unit_b") is not None:
                data_tf = _swap_gate_blocks(data_tf, spec["unit_b"], axis=0)
            if spec["squeeze"] is not None:
                data_tf = np.squeeze(data_tf, axis=spec["squeeze"])
            if spec.get("slice") is not None:
                first, last = spec["slice"][0], spec["slice"][1]
                data_tf = data_tf[first:last]
            if spec.get("scale") is not None:
                data_tf = data_tf * spec["scale"]
            if spec["transpose"] is not None:
                data_tf = np.transpose(data_tf, spec["transpose"])
            _store(name, name_tf, data_tf)

        elif module_name == "bias_embed":
            spec = map_dict[name]
            name_tf = spec["name"]
            data_tf = var_dict_tf[name_tf]
            if spec["squeeze"] is not None:
                data_tf = np.squeeze(data_tf, axis=spec["squeeze"])
            if spec["transpose"] is not None:
                data_tf = np.transpose(data_tf, spec["transpose"])
            _store(name, name_tf, data_tf)

    return var_dict_torch_update
|