| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764 |
- import logging
- from contextlib import contextmanager
- from distutils.version import LooseVersion
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Tuple
- from typing import Union
- import torch
- import random
- import numpy as np
- from typeguard import check_argument_types
- from funasr.layers.abs_normalize import AbsNormalize
- from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
- )
- from funasr.models.ctc import CTC
- from funasr.models.decoder.abs_decoder import AbsDecoder
- from funasr.models.e2e_asr_common import ErrorCalculator
- from funasr.models.encoder.abs_encoder import AbsEncoder
- from funasr.models.frontend.abs_frontend import AbsFrontend
- from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
- from funasr.models.predictor.cif import mae_loss
- from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
- from funasr.models.specaug.abs_specaug import AbsSpecAug
- from funasr.modules.add_sos_eos import add_sos_eos
- from funasr.modules.nets_utils import make_pad_mask, pad_list
- from funasr.modules.nets_utils import th_accuracy
- from funasr.torch_utils.device_funcs import force_gatherable
- from funasr.train.abs_espnet_model import AbsESPnetModel
- from funasr.models.predictor.cif import CifPredictorV3
try:
    # torch >= 1.6.0 provides native mixed-precision support. Feature-detect it
    # with EAFP instead of comparing version strings via distutils.LooseVersion
    # (distutils is deprecated by PEP 632 and removed in Python 3.12).
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0: provide a no-op stand-in with the same call
    # shape so ``with autocast(False):`` works unconditionally.
    @contextmanager
    def autocast(enabled=True):
        yield
class Paraformer(AbsESPnetModel):
    """Non-autoregressive Paraformer ASR model.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    ``forward`` combines up to three training objectives:

    * a CTC loss on the encoder output (optionally mixed with intermediate
      CTC losses), weighted by ``ctc_weight``;
    * a label-smoothed loss on the decoder output, weighted by
      ``1 - ctc_weight``;
    * an MAE loss between reference and predicted token counts produced by
      ``predictor``, weighted by ``predictor_weight`` (only added when the
      decoder branch runs, i.e. ``ctc_weight != 1.0``).
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: AbsDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        extract_feats_in_collect_stats: bool = True,
        predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
    ):
        """Assemble the model from its sub-modules.

        Args:
            vocab_size: Number of output classes for the decoder/CTC.
            token_list: Token strings; stored as a list copy and handed to
                the ErrorCalculator for CER/WER reporting.
            frontend, specaug, normalize, preencoder, postencoder: Optional
                feature-pipeline stages applied in ``encode``; ``None`` skips
                the stage.
            encoder: Acoustic encoder.
            decoder: Decoder; not kept (set to ``None``) when
                ``ctc_weight == 1.0``.
            ctc: CTC module; not kept (set to ``None``) when
                ``ctc_weight == 0.0``.
            ctc_weight: Weight of the CTC loss, must lie in [0, 1].
            interctc_weight: Weight of intermediate CTC losses, in [0, 1).
            ignore_id: Padding label id ignored by losses and accuracy.
            blank_id: CTC blank id (stored only).
            sos, eos: Start/end symbol ids; ``None`` maps to
                ``vocab_size - 1``.
            lsm_weight: Label-smoothing factor for the decoder loss.
            length_normalized_loss: Passed as ``normalize_length`` to the
                decoder and predictor losses.
            report_cer, report_wer: Build an ErrorCalculator when either is
                set; metrics are only computed outside training.
            sym_space, sym_blank: Symbols forwarded to the ErrorCalculator.
            extract_feats_in_collect_stats: See ``collect_feats``.
            predictor: Module returning acoustic embeddings and predicted
                token counts.
            predictor_weight: Weight of the predictor MAE loss.
            predictor_bias: When 1, targets get ``eos`` appended and their
                lengths incremented in ``_calc_att_loss``.
            sampling_ratio: When > 0, ``sampler`` mixes target embeddings
                into the decoder input during ``_calc_att_loss``.
            share_embedding: Look up target embeddings in the decoder's
                output-layer weights instead of ``decoder.embed`` (which is
                then dropped).
        """
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()
        # sos/eos fall back to the last vocabulary id when passed as None.
        self.blank_id = blank_id
        self.sos = vocab_size - 1 if sos is None else sos
        self.eos = vocab_size - 1 if eos is None else eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        # list(...) rather than token_list.copy(): the annotation allows a
        # tuple, and tuples have no .copy() method.
        self.token_list = list(token_list)

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        # Encoders without self-conditioned intermediate CTC support get the
        # flag defaulted so ``encode`` can test it unconditionally.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.error_calculator = None

        if ctc_weight == 1.0:
            # Pure-CTC mode: the decoder would hold unused parameters.
            self.decoder = None
        else:
            self.decoder = decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        # Training-step counter; only used to log the sampler state once.
        self.step_cur = 0

        self.share_embedding = share_embedding
        # Guard: in pure-CTC mode there is no decoder whose embedding to drop.
        if self.share_embedding and self.decoder is not None:
            self.decoder.embed = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``,
            where ``weight`` is derived from the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel: trim padding beyond the longest item in this shard
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()

        # 2a. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight

        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None

        stats["loss"] = torch.clone(loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Return frontend features for statistics collection.

        When ``extract_feats_in_collect_stats`` is False the raw ``speech``
        is returned unchanged as dummy features (with a warning). ``text``
        and ``text_lengths`` are accepted for interface compatibility only.
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            ``(encoder_out, encoder_out_lens)``; when the encoder also yields
            intermediate outputs, the first element is the tuple
            ``(encoder_out, intermediate_outs)``.
        """
        # Feature extraction and augmentation always run in full precision.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        return encoder_out, encoder_out_lens

    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the predictor without targets (inference).

        Returns ``(pre_acoustic_embeds, pre_token_length, alphas,
        pre_peak_index)`` exactly as returned by ``self.predictor``.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None, encoder_out_mask,
                                                                                       ignore_id=self.ignore_id)
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
        """Decode predictor embeddings; returns ``(log_probs, ys_pad_lens)``."""
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the frontend if present, else pass ``speech`` through."""
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Compute negative log likelihood(nll) from transformer-decoder

        Normally, this function is called in batchify_nll.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)

        Returns:
            Per-utterance summed NLL, shape (Batch,).
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )  # [batch, seqlen, dim]
        batch_size = decoder_out.size(0)
        decoder_num_class = decoder_out.size(2)
        # nll: negative log-likelihood
        nll = torch.nn.functional.cross_entropy(
            decoder_out.view(-1, decoder_num_class),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        nll = nll.sum(dim=1)
        assert nll.size(0) == batch_size
        return nll

    def batchify_nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        batch_size: int = 100,
    ):
        """Compute negative log likelihood(nll) from transformer-decoder

        To avoid OOM, this function separates the input into batches.
        Then call nll for each batch and combine and return results.

        Args:
            encoder_out: (Batch, Length, Dim)
            encoder_out_lens: (Batch,)
            ys_pad: (Batch, Length)
            ys_pad_lens: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase
                        GPU memory usage

        Raises:
            ValueError: if chunking is needed and ``batch_size`` < 1 (the
                chunk loop could otherwise never terminate).
        """
        total_num = encoder_out.size(0)
        if total_num <= batch_size:
            nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        else:
            if batch_size < 1:
                raise ValueError(f"batch_size must be >= 1, got {batch_size}")
            # Slicing past the end is safe, so the last chunk may be short.
            nll = torch.cat(
                [
                    self.nll(
                        encoder_out[start:start + batch_size, :, :],
                        encoder_out_lens[start:start + batch_size],
                        ys_pad[start:start + batch_size, :],
                        ys_pad_lens[start:start + batch_size],
                    )
                    for start in range(0, total_num, batch_size)
                ]
            )
        assert nll.size(0) == total_num
        return nll

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Decoder + predictor losses.

        Returns ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``;
        ``cer_att``/``wer_att`` are None in training mode or when no
        ErrorCalculator is configured.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                                  ignore_id=self.ignore_id)

        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]

        # Accuracy/CER/WER are measured on the first (predictor-only) pass
        # when the sampler ran, otherwise on this pass.
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Mix ground-truth target embeddings into the decoder input.

        A first decoder pass (no grad) on ``pre_acoustic_embeds`` counts,
        per utterance, the non-pad positions it already predicts correctly.
        ``sampling_ratio * (length - correct)`` random positions then take
        the target embedding instead of the predictor embedding.

        Returns:
            ``(sematic_embeds, decoder_out)``, both multiplied by the target
            padding mask; ``decoder_out`` is the first-pass output.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        # Zero padded positions (which hold ignore_id) so they are valid
        # indices for the embedding lookup below.
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = nonpad_positions.sum(1)
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            input_mask = torch.ones_like(nonpad_positions)
            bsz, _ = ys_pad.size()
            for li in range(bsz):
                target_num = ((seq_lens[li] - same_num[li]).float() * self.sampling_ratio).long()
                if target_num > 0:
                    # Draw on the default generator, then move the indices to
                    # the mask's device so this works on CPU and on any GPU.
                    positions = torch.randperm(int(seq_lens[li]))[:target_num].to(input_mask.device)
                    input_mask[li].scatter_(dim=0, index=positions, value=0)
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)

        # Outside no_grad so gradients flow into both embedding sources.
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Return ``(loss_ctc, cer_ctc)``; CER only outside training."""
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
class ParaformerOnline(Paraformer):
    """Streaming (chunk-wise) Paraformer.

    Author: Speech Lab, Alibaba Group, China
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    Construction and the training ``forward`` pass are inherited unchanged
    from :class:`Paraformer`. The methods below run inference one chunk at
    a time, threading a ``cache`` dict whose ``"encoder"`` and ``"decoder"``
    entries are handed to the sub-modules' ``forward_chunk`` methods.
    """

    def encode_chunk(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + chunk-wise encoder for streaming inference.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            cache: Streaming state; ``cache["encoder"]`` is required despite
                the ``None`` default.

        Returns:
            ``(encoder_out, lens)``. Without intermediate outputs ``lens`` is
            the one-element tensor ``[encoder_out.size(1)]`` (created on the
            default device), i.e. this path assumes a single utterance.
        """
        # Feature extraction and augmentation always run in full precision.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
                feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )

        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens

        # NOTE(review): returns the chunk length rather than encoder_out_lens,
        # which only makes sense for batch size 1 — confirm with callers.
        return encoder_out, torch.tensor([encoder_out.size(1)])

    def calc_predictor_chunk(self, encoder_out, cache=None):
        """Run the predictor on one encoded chunk.

        The predictor's streaming state is read from ``cache["encoder"]``.
        Returns ``(pre_acoustic_embeds, pre_token_length)``.
        """
        pre_acoustic_embeds, pre_token_length = \
            self.predictor.forward_chunk(encoder_out, cache["encoder"])
        return pre_acoustic_embeds, pre_token_length

    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
        """Decode one chunk with ``cache["decoder"]``; returns log-probs."""
        decoder_out = self.decoder.forward_chunk(
            encoder_out, sematic_embeds, cache["decoder"]
        )
        return torch.log_softmax(decoder_out, dim=-1)
- class ParaformerBert(Paraformer):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
- """
- def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
- decoder: AbsDecoder,
- ctc: CTC,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- extract_feats_in_collect_stats: bool = True,
- predictor=None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
- embeds_id: int = 2,
- embeds_loss_weight: float = 0.0,
- embed_dims: int = 768,
- ):
- assert check_argument_types()
- assert 0.0 <= ctc_weight <= 1.0, ctc_weight
- assert 0.0 <= interctc_weight < 1.0, interctc_weight
- super().__init__(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- preencoder=preencoder,
- encoder=encoder,
- postencoder=postencoder,
- decoder=decoder,
- ctc=ctc,
- ctc_weight=ctc_weight,
- interctc_weight=interctc_weight,
- ignore_id=ignore_id,
- blank_id=blank_id,
- sos=sos,
- eos=eos,
- lsm_weight=lsm_weight,
- length_normalized_loss=length_normalized_loss,
- report_cer=report_cer,
- report_wer=report_wer,
- sym_space=sym_space,
- sym_blank=sym_blank,
- extract_feats_in_collect_stats=extract_feats_in_collect_stats,
- predictor=predictor,
- predictor_weight=predictor_weight,
- predictor_bias=predictor_bias,
- sampling_ratio=sampling_ratio,
- )
- self.decoder.embeds_id = embeds_id
- decoder_attention_dim = self.decoder.attention_dim
- self.pro_nn = torch.nn.Linear(decoder_attention_dim, embed_dims)
- self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
- self.embeds_loss_weight = embeds_loss_weight
- self.length_normalized_loss = length_normalized_loss
def _calc_embed_loss(self,
                     ys_pad: torch.Tensor,
                     ys_pad_lens: torch.Tensor,
                     embed: torch.Tensor = None,
                     embed_lengths: torch.Tensor = None,
                     embeds_outputs: torch.Tensor = None,
                     ):
    """Cosine-distance loss between projected decoder embeddings and target embeddings.

    ``embeds_outputs`` is projected with ``self.pro_nn`` and compared to ``embed``
    position by position with ``self.cos``. Positions at or beyond
    ``ys_pad_lens`` are masked out before summing.

    Args:
        ys_pad: padded target token ids; only its device is used here.
        ys_pad_lens: valid length of each target sequence, used to build the mask.
        embed: target embeddings, presumably (batch, length, embed_dims) -- TODO confirm.
            It is NOT modified by this method.
        embed_lengths: not read by this method (kept for interface compatibility).
        embeds_outputs: extra decoder output that is fed through ``self.pro_nn``.

    Returns:
        Scalar tensor: the sum of ``1 - cos`` over valid positions, divided by the
        number of valid positions when ``self.length_normalized_loss`` is set,
        otherwise by the number of sequences (``tgt_mask.size(0)``).
    """
    projected = self.pro_nn(embeds_outputs)
    # (batch, max_len, 1) boolean mask, True on real (non-padded) positions.
    tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
    # Mask out of place: the previous in-place ``embed *= tgt_mask`` silently
    # overwrote the caller's tensor (and modified an autograd activation in place).
    projected = projected * tgt_mask  # b x l x d
    target = embed * tgt_mask  # b x l x d
    cos_loss = (1.0 - self.cos(projected, target)) * tgt_mask.squeeze(2)
    if self.length_normalized_loss:
        token_num_total = torch.sum(tgt_mask)
    else:
        token_num_total = tgt_mask.size()[0]
    return torch.sum(cos_loss) / token_num_total
def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
):
    """Attention-decoder branch: predictor, optional sampler, decoder, losses.

    Args:
        encoder_out: encoder output; dim 1 is the frame axis used for the mask.
        encoder_out_lens: valid frame count per utterance.
        ys_pad: padded target token ids.
        ys_pad_lens: valid target length per utterance.

    Returns:
        Tuple ``(loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs)``.
        ``embeds_outputs`` is the decoder's third output when it returns more than
        two values, otherwise ``None``. ``cer_att``/``wer_att`` are ``None`` while
        training or when no error calculator is configured.
    """
    frame_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
        encoder_out.device)
    if self.predictor_bias == 1:
        # Targets are rewritten by add_sos_eos and lengths grow by predictor_bias.
        _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias
    acoustic_embeds, token_length, _alphas, _peak_index = self.predictor(
        encoder_out, ys_pad, frame_mask, ignore_id=self.ignore_id)

    # 0. sampler (only when sampling_ratio is positive)
    first_pass_out = None
    if self.sampling_ratio > 0.0:
        if self.step_cur < 2:
            logging.info(
                "enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
        decoder_inputs, first_pass_out = self.sampler(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, acoustic_embeds)
    else:
        if self.step_cur < 2:
            logging.info(
                "disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
        decoder_inputs = acoustic_embeds

    # 1. Forward decoder
    outs = self.decoder(encoder_out, encoder_out_lens, decoder_inputs, ys_pad_lens)
    decoder_out, _ = outs[0], outs[1]
    embeds_outputs = outs[2] if len(outs) > 2 else None
    # Accuracy / CER use the sampler's first-pass output when one was produced.
    scoring_out = decoder_out if first_pass_out is None else first_pass_out

    # 2. Losses and accuracy
    loss_att = self.criterion_att(decoder_out, ys_pad)
    acc_att = th_accuracy(
        scoring_out.view(-1, self.vocab_size),
        ys_pad,
        ignore_label=self.ignore_id,
    )
    loss_pre = self.criterion_pre(ys_pad_lens.type_as(token_length), token_length)

    # CER/WER only in evaluation mode with an error calculator available.
    cer_att = wer_att = None
    if not self.training and self.error_calculator is not None:
        hyp_tokens = scoring_out.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(hyp_tokens.cpu(), ys_pad.cpu())
    return loss_att, acc_att, cer_att, wer_att, loss_pre, embeds_outputs
def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        embed: torch.Tensor = None,
        embed_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Frontend + Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)
        embed: optional target embeddings for the cosine embedding loss. When
            ``embed_lengths`` is also given, it is trimmed to
            ``embed_lengths.max()`` along dim 1.
        embed_lengths: (Batch,) valid lengths of ``embed``.

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable``. The total
        loss adds ``cos_loss * self.embeds_loss_weight`` to the usual
        CTC / attention / predictor terms whenever the attention branch runs.
    """
    assert text_lengths.dim() == 1, text_lengths.shape
    # Check that batch_size is unified
    assert (
        speech.shape[0]
        == speech_lengths.shape[0]
        == text.shape[0]
        == text_lengths.shape[0]
    ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
    batch_size = speech.shape[0]
    self.step_cur += 1
    # for data-parallel
    text = text[:, : text_lengths.max()]
    # Trim only the time axis so both raw waveforms (Batch, Length) and
    # features (Batch, Length, Dim) work; ``[:, :T, :]`` raised IndexError on 2-D input.
    speech = speech[:, :speech_lengths.max()]
    if embed is not None and embed_lengths is not None:
        embed = embed[:, :embed_lengths.max(), :]
    # 1. Encoder
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    intermediate_outs = None
    if isinstance(encoder_out, tuple):
        intermediate_outs = encoder_out[1]
        encoder_out = encoder_out[0]
    loss_att, acc_att, cer_att, wer_att = None, None, None, None
    loss_ctc, cer_ctc = None, None
    loss_pre = 0.0
    cos_loss = 0.0
    stats = dict()
    # 1. CTC branch
    if self.ctc_weight != 0.0:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        # Collect CTC branch stats
        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
        stats["cer_ctc"] = cer_ctc
    # Intermediate CTC (optional)
    loss_interctc = 0.0
    if self.interctc_weight != 0.0 and intermediate_outs is not None:
        for layer_idx, intermediate_out in intermediate_outs:
            # we assume intermediate_out has the same length & padding
            # as those of encoder_out
            loss_ic, cer_ic = self._calc_ctc_loss(
                intermediate_out, encoder_out_lens, text, text_lengths
            )
            loss_interctc = loss_interctc + loss_ic
            # Collect Intermediate CTC stats
            stats["loss_interctc_layer{}".format(layer_idx)] = (
                loss_ic.detach() if loss_ic is not None else None
            )
            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
        loss_interctc = loss_interctc / len(intermediate_outs)
        # calculate whole encoder loss
        loss_ctc = (
            1 - self.interctc_weight
        ) * loss_ctc + self.interctc_weight * loss_interctc
    # 2b. Attention decoder branch
    if self.ctc_weight != 1.0:
        loss_ret = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        loss_att, acc_att, cer_att, wer_att, loss_pre = loss_ret[:5]
        # A 6th element (decoder embedding outputs) is present only when the
        # decoder returns extra outputs.
        embeds_outputs = loss_ret[5] if len(loss_ret) > 5 else None
        if embeds_outputs is not None:
            cos_loss = self._calc_embed_loss(text, text_lengths, embed, embed_lengths, embeds_outputs)
    # 3. CTC-Att loss definition
    if self.ctc_weight == 0.0:
        loss = loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
    elif self.ctc_weight == 1.0:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (
            1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + cos_loss * self.embeds_loss_weight
    # Collect Attn branch stats
    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
    stats["acc"] = acc_att
    stats["cer"] = cer_att
    stats["wer"] = wer_att
    # loss_pre / cos_loss stay the float 0.0 when their branch did not run.
    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre > 0.0 else None
    stats["cos_loss"] = cos_loss.detach().cpu() if cos_loss > 0.0 else None
    stats["loss"] = torch.clone(loss.detach())
    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
class BiCifParaformer(Paraformer):
    """
    Paraformer model with an extra cif predictor
    to conduct accurate timestamp prediction
    """

    def __init__(
            self,
            vocab_size: int,
            token_list: Union[Tuple[str, ...], List[str]],
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            preencoder: Optional[AbsPreEncoder],
            encoder: AbsEncoder,
            postencoder: Optional[AbsPostEncoder],
            decoder: AbsDecoder,
            ctc: CTC,
            ctc_weight: float = 0.5,
            interctc_weight: float = 0.0,
            ignore_id: int = -1,
            blank_id: int = 0,
            sos: int = 1,
            eos: int = 2,
            lsm_weight: float = 0.0,
            length_normalized_loss: bool = False,
            report_cer: bool = True,
            report_wer: bool = True,
            sym_space: str = "<space>",
            sym_blank: str = "<blank>",
            extract_feats_in_collect_stats: bool = True,
            predictor=None,
            predictor_weight: float = 0.0,
            predictor_bias: int = 0,
            sampling_ratio: float = 0.2,
    ):
        """Forward every argument to ``Paraformer`` and require a ``CifPredictorV3`` predictor."""
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )
        assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"

    def _calc_pre2_loss(
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
    ):
        """Length loss on the predictor's fifth output (``pre_token_length2``).

        Runs ``self.predictor`` on the same inputs as ``_calc_att_loss``, including
        the optional sos/eos handling when ``predictor_bias == 1``. It then compares
        ``pre_token_length2`` with the target lengths using ``self.criterion_pre``.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
        return loss_pre2

    def _calc_att_loss(
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder branch using the 5-output ``CifPredictorV3``.

        Returns:
            Tuple ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``.
            ``cer_att``/``wer_att`` are ``None`` while training or when no error
            calculator is configured.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
                                                                                     encoder_out_mask,
                                                                                     ignore_id=self.ignore_id)
        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def calc_predictor(self, encoder_out, encoder_out_lens):
        """Run the predictor without targets (inference).

        Returns the first four predictor outputs; the fifth
        (``pre_token_length2``) is discarded.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, _ = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id)
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        """Delegate to ``self.predictor.get_upsample_timestamp`` with a frame mask built from ``encoder_out_lens``."""
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                            encoder_out_mask,
                                                                                            token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``. When the
            attention branch runs, the loss includes ``loss_pre2`` weighted by
            ``0.5 * predictor_weight``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        # Initialised here: it is only computed in the attention branch, and
        # reading it unconditionally below crashed with ctc_weight == 1.0.
        loss_pre2 = None
        stats = dict()
        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss_pre2 = self._calc_pre2_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre2 is not None else None
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
class ContextualParaformer(Paraformer):
    """
    Paraformer model with contextual hotword
    """

    def __init__(
            self,
            vocab_size: int,
            token_list: Union[Tuple[str, ...], List[str]],
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            preencoder: Optional[AbsPreEncoder],
            encoder: AbsEncoder,
            postencoder: Optional[AbsPostEncoder],
            decoder: AbsDecoder,
            ctc: CTC,
            ctc_weight: float = 0.5,
            interctc_weight: float = 0.0,
            ignore_id: int = -1,
            blank_id: int = 0,
            sos: int = 1,
            eos: int = 2,
            lsm_weight: float = 0.0,
            length_normalized_loss: bool = False,
            report_cer: bool = True,
            report_wer: bool = True,
            sym_space: str = "<space>",
            sym_blank: str = "<blank>",
            extract_feats_in_collect_stats: bool = True,
            predictor=None,
            predictor_weight: float = 0.0,
            predictor_bias: int = 0,
            sampling_ratio: float = 0.2,
            min_hw_length: int = 2,
            max_hw_length: int = 4,
            sample_rate: float = 0.6,
            batch_rate: float = 0.5,
            double_rate: float = -1.0,
            target_buffer_length: int = -1,
            inner_dim: int = 256,
            bias_encoder_type: str = 'lstm',
            label_bracket: bool = False,
            use_decoder_embedding: bool = False,
    ):
        """Build a Paraformer plus a hotword (bias) encoder.

        Hotword-specific args:
            min_hw_length / max_hw_length: bounds on sampled hotword span lengths.
            sample_rate: per-utterance probability of sampling one hotword.
            batch_rate: stored on the instance; not read by this class.
            double_rate: probability of sampling two hotwords from a long
                utterance (the default -1.0 never triggers).
            target_buffer_length: when > 0, sampled hotword embeddings are kept
                in a rolling buffer across training steps.
            inner_dim: embedding and LSTM size of the bias encoder.
            bias_encoder_type: only 'lstm' creates ``bias_encoder``/``bias_embed``.
            label_bracket: accepted but not used by this class.
            use_decoder_embedding: embed hotword tokens with ``self.decoder.embed``
                instead of ``self.bias_embed``.
        """
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__(
            vocab_size=vocab_size,
            token_list=token_list,
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            preencoder=preencoder,
            encoder=encoder,
            postencoder=postencoder,
            decoder=decoder,
            ctc=ctc,
            ctc_weight=ctc_weight,
            interctc_weight=interctc_weight,
            ignore_id=ignore_id,
            blank_id=blank_id,
            sos=sos,
            eos=eos,
            lsm_weight=lsm_weight,
            length_normalized_loss=length_normalized_loss,
            report_cer=report_cer,
            report_wer=report_wer,
            sym_space=sym_space,
            sym_blank=sym_blank,
            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
            predictor=predictor,
            predictor_weight=predictor_weight,
            predictor_bias=predictor_bias,
            sampling_ratio=sampling_ratio,
        )
        if bias_encoder_type == 'lstm':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)
            self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)
        else:
            # NOTE(review): no bias encoder is created here, so later calls that
            # use self.bias_encoder will fail with AttributeError.
            logging.error("Unsupport bias encoder type")
        self.min_hw_length = min_hw_length
        self.max_hw_length = max_hw_length
        self.sample_rate = sample_rate
        self.batch_rate = batch_rate
        self.target_buffer_length = target_buffer_length
        self.double_rate = double_rate
        if self.target_buffer_length > 0:
            self.hotword_buffer = None
            self.length_record = []
            self.current_buffer_length = 0
        self.use_decoder_embedding = use_decoder_embedding

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()
        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect Intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc
        # 2b. Attention decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def _sample_hot_word(self, ys_pad, ys_pad_lens):
        """Sample random hotword spans from the targets and encode them.

        Entry 0 of the hotword list is always the single token ``0``. Each
        utterance then contributes zero, one or (with ``double_rate``) two
        spans. The spans are embedded, run through ``self.bias_encoder``, and the
        LSTM output at each span's last token is selected.

        When ``target_buffer_length > 0``, the selection is merged into a
        rolling buffer of earlier (detached) hotword embeddings.

        Returns:
            Hotword embeddings repeated over the batch:
            (batch, num_hotwords, hidden), with hidden the LSTM size.
        """
        hw_list = [torch.Tensor([0]).long().to(ys_pad.device)]
        hw_lengths = [0]  # index of each hotword's last token, i.e. length - 1
        for i, length in enumerate(ys_pad_lens):
            if length < 2:
                continue
            if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate:
                # sample double hotword
                _max_hw_length = min(self.max_hw_length, length // 2)
                # first hotword
                start1 = random.randint(0, length // 3)
                end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1)
                hw_tokens1 = ys_pad[i][start1:end1 + 1]
                hw_lengths.append(len(hw_tokens1) - 1)
                hw_list.append(hw_tokens1)
                # second hotword
                start2 = random.randint(end1 + 1, length - self.min_hw_length)
                end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1),
                                      min(length - 1, start2 + self.max_hw_length - 1))
                hw_tokens2 = ys_pad[i][start2:end2 + 1]
                hw_lengths.append(len(hw_tokens2) - 1)
                hw_list.append(hw_tokens2)
                continue
            if random.random() < self.sample_rate:
                if length == 2:
                    hw_tokens = ys_pad[i][:2]
                    hw_lengths.append(1)
                    hw_list.append(hw_tokens)
                else:
                    start = random.randint(0, length - self.min_hw_length)
                    end = random.randint(min(length - 1, start + self.min_hw_length - 1),
                                         min(length - 1, start + self.max_hw_length - 1)) + 1
                    hw_tokens = ys_pad[i][start:end]
                    hw_lengths.append(len(hw_tokens) - 1)
                    hw_list.append(hw_tokens)
        # padding
        hw_list_pad = pad_list(hw_list, 0)
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hw_list_pad)
        else:
            hw_embed = self.bias_embed(hw_list_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        _ind = np.arange(0, len(hw_list)).tolist()
        # Pick the LSTM output at each hotword's last (non-pad) token.
        selected = hw_embed[_ind, hw_lengths]
        # update self.hotword_buffer, throw a part if oversize
        if self.target_buffer_length > 0:
            _b = selected.shape[0]
            if self.hotword_buffer is None:
                self.hotword_buffer = selected
                self.length_record.append(selected.shape[0])
                self.current_buffer_length = _b
            elif self.current_buffer_length + _b < self.target_buffer_length:
                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
                self.current_buffer_length += _b
                selected = self.hotword_buffer
            else:
                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
                random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
                self.hotword_buffer = self.hotword_buffer[-1 * random_throw:]
                selected = self.hotword_buffer
                self.current_buffer_length = selected.shape[0]
        return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)

    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
        """Mix predictor embeddings with ground-truth token embeddings.

        A first decoder pass (no grad) is compared with the targets. For each
        utterance, ``sampling_ratio`` times the number of mismatched tokens
        positions are chosen at random. At those positions the predictor's
        acoustic embedding is replaced by the target token embedding.

        Returns:
            ``(sematic_embeds, decoder_out)``, both multiplied by the target mask.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
        # Padding ids (e.g. ignore_id) are zeroed only for the embedding lookup.
        # The original ``ys_pad`` is kept so ``ne(ignore_id)`` below still sees
        # the padding. Overwriting ys_pad made every position look non-padded.
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        input_mask = torch.ones_like(nonpad_positions)
        bsz, seq_len = ys_pad.size()
        for li in range(bsz):
            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
            if target_num > 0:
                # Move indices to the mask's own device; the old hard-coded
                # .cuda() failed on CPU and targeted the wrong GPU in multi-GPU runs.
                replace_idx = torch.randperm(seq_lens[li])[:target_num].to(input_mask.device)
                input_mask[li].scatter_(dim=0, index=replace_idx, value=0)
        input_mask = input_mask.eq(1)
        input_mask = input_mask.masked_fill(~nonpad_positions, False)
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
            input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask

    def _calc_att_loss(
            self,
            encoder_out: torch.Tensor,
            encoder_out_lens: torch.Tensor,
            ys_pad: torch.Tensor,
            ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder branch with sampled hotwords as contextual info.

        Returns:
            Tuple ``(loss_att, acc_att, cer_att, wer_att, loss_pre)``.
            ``cer_att``/``wer_att`` are ``None`` while training or when no error
            calculator is configured.
        """
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad,
                                                                                  encoder_out_mask,
                                                                                  ignore_id=self.ignore_id)
        # sample hot word
        contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)
        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
                                                           pre_acoustic_embeds, contextual_info)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre

    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
        """Decode with a hotword list as contextual info (inference).

        Args:
            hw_list: hotwords as sequences of token ids. When ``None``, a single
                ``[self.sos]`` entry is used as the empty hotword list.

        Returns:
            ``(log_softmax(decoder_out), ys_pad_lens)``. The contextual info is the
            final LSTM hidden state of each hotword, repeated over the batch.
        """
        if hw_list is None:
            # default hotword list
            hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)]  # empty hotword list
            hw_list_pad = pad_list(hw_list, 0)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            _, (h_n, _) = self.bias_encoder(hw_embed)
            contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
        else:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
            if self.use_decoder_embedding:
                hw_embed = self.decoder.embed(hw_list_pad)
            else:
                hw_embed = self.bias_embed(hw_list_pad)
            # Pack so the LSTM's final hidden state comes from each hotword's
            # real last token rather than from padding.
            hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
                                                               enforce_sorted=False)
            _, (h_n, _) = self.bias_encoder(hw_embed)
            contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens

    def gen_clas_tf2torch_map_dict(self):
        """Return the TF-checkpoint -> torch parameter mapping for the bias encoder and embedding.

        The hard-coded slice/unit sizes (512) assume a 512-unit LSTM in the
        source checkpoint.
        """
        tensor_name_prefix_torch = "bias_encoder"
        tensor_name_prefix_tf = "seq2seq/clas_charrnn"
        tensor_name_prefix_torch_emb = "bias_embed"
        tensor_name_prefix_tf_emb = "seq2seq"
        map_dict_local = {
            # in lstm
            "{}.weight_ih_l0".format(tensor_name_prefix_torch):
                {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (1, 0),
                 "slice": (0, 512),
                 "unit_k": 512,
                 },  # (1024, 2048),(2048,512)
            "{}.weight_hh_l0".format(tensor_name_prefix_torch):
                {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (1, 0),
                 "slice": (512, 1024),
                 "unit_k": 512,
                 },  # (1024, 2048),(2048,512)
            "{}.bias_ih_l0".format(tensor_name_prefix_torch):
                {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 "scale": 0.5,
                 "unit_b": 512,
                 },  # (2048,),(2048,)
            "{}.bias_hh_l0".format(tensor_name_prefix_torch):
                {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 "scale": 0.5,
                 "unit_b": 512,
                 },  # (2048,),(2048,)
            # in embed
            "{}.weight".format(tensor_name_prefix_torch_emb):
                {"name": "{}/contextual_encoder/w_char_embs".format(tensor_name_prefix_tf_emb),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)
        }
        return map_dict_local

    def clas_convert_tf2torch(self,
                              var_dict_tf,
                              var_dict_torch):
        """Convert bias-encoder / bias-embedding tensors from a TF checkpoint.

        Args:
            var_dict_tf: TF variable name -> numpy array.
            var_dict_torch: torch state-dict (name -> tensor); only names starting
                with ``bias_encoder`` or ``bias_embed`` are converted.

        Returns:
            Dict of converted float32 CPU tensors keyed by torch parameter name.
        """
        map_dict = self.gen_clas_tf2torch_map_dict()
        var_dict_torch_update = dict()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == "bias_encoder":
                name_q = name
                if name_q in map_dict.keys():
                    name_v = map_dict[name_q]["name"]
                    name_tf = name_v
                    data_tf = var_dict_tf[name_tf]
                    if map_dict[name_q].get("unit_k") is not None:
                        dim = map_dict[name_q]["unit_k"]
                        # Reorder the four gate chunks along the output axis as
                        # [0, 2, 1, 3] (swap chunks 1 and 2). Presumably this maps
                        # TF's (i, j, f, o) packing to nn.LSTM's (i, f, g, o) --
                        # confirm against the source checkpoint.
                        c0, c1, c2, c3 = (data_tf[:, k * dim:(k + 1) * dim].copy() for k in range(4))
                        data_tf = np.concatenate([c0, c2, c1, c3], axis=1)
                    if map_dict[name_q].get("unit_b") is not None:
                        dim = map_dict[name_q]["unit_b"]
                        # Same [0, 2, 1, 3] gate reorder for the bias vector.
                        c0, c1, c2, c3 = (data_tf[k * dim:(k + 1) * dim].copy() for k in range(4))
                        data_tf = np.concatenate([c0, c2, c1, c3], axis=0)
                    if map_dict[name_q]["squeeze"] is not None:
                        data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                    if map_dict[name_q].get("slice") is not None:
                        # TF stores [input; hidden] rows in one kernel; keep this parameter's rows.
                        data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]]
                    if map_dict[name_q].get("scale") is not None:
                        # TF has one bias; torch has bias_ih + bias_hh, so each gets half.
                        data_tf = data_tf * map_dict[name_q]["scale"]
                    if map_dict[name_q]["transpose"] is not None:
                        data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                    var_dict_torch[
                                                                                                        name].size(),
                                                                                                    data_tf.size())
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                      var_dict_tf[name_tf].shape))
            elif names[0] == "bias_embed":
                name_tf = map_dict[name]["name"]
                data_tf = var_dict_tf[name_tf]
                if map_dict[name]["squeeze"] is not None:
                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                if map_dict[name]["transpose"] is not None:
                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
                                                                                                var_dict_torch[
                                                                                                    name].size(),
                                                                                                data_tf.size())
                var_dict_torch_update[name] = data_tf
                logging.info(
                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                  var_dict_tf[name_tf].shape))
        return var_dict_torch_update
|