|
|
@@ -325,56 +325,6 @@ class Paraformer(AbsESPnetModel):
|
|
|
|
|
|
return encoder_out, encoder_out_lens
|
|
|
|
|
|
- def encode_chunk(
|
|
|
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
|
|
|
- ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
- """Frontend + Encoder. Note that this method is used by asr_inference.py
|
|
|
-
|
|
|
- Args:
|
|
|
- speech: (Batch, Length, ...)
|
|
|
- speech_lengths: (Batch, )
|
|
|
- """
|
|
|
- with autocast(False):
|
|
|
- # 1. Extract feats
|
|
|
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
|
|
-
|
|
|
- # 2. Data augmentation
|
|
|
- if self.specaug is not None and self.training:
|
|
|
- feats, feats_lengths = self.specaug(feats, feats_lengths)
|
|
|
-
|
|
|
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
|
|
- if self.normalize is not None:
|
|
|
- feats, feats_lengths = self.normalize(feats, feats_lengths)
|
|
|
-
|
|
|
- # Pre-encoder, e.g. used for raw input data
|
|
|
- if self.preencoder is not None:
|
|
|
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
|
|
-
|
|
|
- # 4. Forward encoder
|
|
|
- # feats: (Batch, Length, Dim)
|
|
|
- # -> encoder_out: (Batch, Length2, Dim2)
|
|
|
- if self.encoder.interctc_use_conditioning:
|
|
|
- encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
|
|
|
- feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
|
|
|
- )
|
|
|
- else:
|
|
|
- encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
|
|
|
- intermediate_outs = None
|
|
|
- if isinstance(encoder_out, tuple):
|
|
|
- intermediate_outs = encoder_out[1]
|
|
|
- encoder_out = encoder_out[0]
|
|
|
-
|
|
|
- # Post-encoder, e.g. NLU
|
|
|
- if self.postencoder is not None:
|
|
|
- encoder_out, encoder_out_lens = self.postencoder(
|
|
|
- encoder_out, encoder_out_lens
|
|
|
- )
|
|
|
-
|
|
|
- if intermediate_outs is not None:
|
|
|
- return (encoder_out, intermediate_outs), encoder_out_lens
|
|
|
-
|
|
|
- return encoder_out, torch.tensor([encoder_out.size(1)])
|
|
|
-
|
|
|
def calc_predictor(self, encoder_out, encoder_out_lens):
|
|
|
|
|
|
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
|
|
@@ -383,11 +333,6 @@ class Paraformer(AbsESPnetModel):
|
|
|
ignore_id=self.ignore_id)
|
|
|
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
|
|
|
|
|
- def calc_predictor_chunk(self, encoder_out, cache=None):
|
|
|
-
|
|
|
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
|
|
|
- return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
|
|
-
|
|
|
def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
|
|
|
|
|
|
decoder_outs = self.decoder(
|
|
|
@@ -397,14 +342,6 @@ class Paraformer(AbsESPnetModel):
|
|
|
decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
|
|
return decoder_out, ys_pad_lens
|
|
|
|
|
|
- def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
|
|
|
- decoder_outs = self.decoder.forward_chunk(
|
|
|
- encoder_out, sematic_embeds, cache["decoder"]
|
|
|
- )
|
|
|
- decoder_out = decoder_outs
|
|
|
- decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
|
|
- return decoder_out
|
|
|
-
|
|
|
def _extract_feats(
|
|
|
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
@@ -610,6 +547,184 @@ class Paraformer(AbsESPnetModel):
|
|
|
return loss_ctc, cer_ctc
|
|
|
|
|
|
|
|
|
+class ParaformerOnline(Paraformer):
|
|
|
+ """
|
|
|
+ Author: Speech Lab, Alibaba Group, China
|
|
|
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
|
|
+ https://arxiv.org/abs/2206.08317
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self, *args, **kwargs,
|
|
|
+ ):
|
|
|
+ super().__init__(*args, **kwargs)
|
|
|
+
|
|
|
+ def forward(
|
|
|
+ self,
|
|
|
+ speech: torch.Tensor,
|
|
|
+ speech_lengths: torch.Tensor,
|
|
|
+ text: torch.Tensor,
|
|
|
+ text_lengths: torch.Tensor,
|
|
|
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
|
|
+ """Frontend + Encoder + Decoder + Calc loss
|
|
|
+ Args:
|
|
|
+ speech: (Batch, Length, ...)
|
|
|
+ speech_lengths: (Batch, )
|
|
|
+ text: (Batch, Length)
|
|
|
+ text_lengths: (Batch,)
|
|
|
+ """
|
|
|
+ assert text_lengths.dim() == 1, text_lengths.shape
|
|
|
+ # Check that batch_size is unified
|
|
|
+ assert (
|
|
|
+ speech.shape[0]
|
|
|
+ == speech_lengths.shape[0]
|
|
|
+ == text.shape[0]
|
|
|
+ == text_lengths.shape[0]
|
|
|
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
|
|
+ batch_size = speech.shape[0]
|
|
|
+ self.step_cur += 1
|
|
|
+ # for data-parallel
|
|
|
+ text = text[:, : text_lengths.max()]
|
|
|
+ speech = speech[:, :speech_lengths.max()]
|
|
|
+
|
|
|
+ # 1. Encoder
|
|
|
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
|
|
+ intermediate_outs = None
|
|
|
+ if isinstance(encoder_out, tuple):
|
|
|
+ intermediate_outs = encoder_out[1]
|
|
|
+ encoder_out = encoder_out[0]
|
|
|
+
|
|
|
+ loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
|
|
+ loss_ctc, cer_ctc = None, None
|
|
|
+ loss_pre = None
|
|
|
+ stats = dict()
|
|
|
+
|
|
|
+ # 1. CTC branch
|
|
|
+ if self.ctc_weight != 0.0:
|
|
|
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
|
|
|
+ encoder_out, encoder_out_lens, text, text_lengths
|
|
|
+ )
|
|
|
+
|
|
|
+ # Collect CTC branch stats
|
|
|
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
|
|
+ stats["cer_ctc"] = cer_ctc
|
|
|
+
|
|
|
+ # Intermediate CTC (optional)
|
|
|
+ loss_interctc = 0.0
|
|
|
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
|
|
+ for layer_idx, intermediate_out in intermediate_outs:
|
|
|
+ # we assume intermediate_out has the same length & padding
|
|
|
+ # as those of encoder_out
|
|
|
+ loss_ic, cer_ic = self._calc_ctc_loss(
|
|
|
+ intermediate_out, encoder_out_lens, text, text_lengths
|
|
|
+ )
|
|
|
+ loss_interctc = loss_interctc + loss_ic
|
|
|
+
|
|
|
+ # Collect Intermedaite CTC stats
|
|
|
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
|
|
|
+ loss_ic.detach() if loss_ic is not None else None
|
|
|
+ )
|
|
|
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
|
|
+
|
|
|
+ loss_interctc = loss_interctc / len(intermediate_outs)
|
|
|
+
|
|
|
+ # calculate whole encoder loss
|
|
|
+ loss_ctc = (
|
|
|
+ 1 - self.interctc_weight
|
|
|
+ ) * loss_ctc + self.interctc_weight * loss_interctc
|
|
|
+
|
|
|
+ # 2b. Attention decoder branch
|
|
|
+ if self.ctc_weight != 1.0:
|
|
|
+ loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
|
|
|
+ encoder_out, encoder_out_lens, text, text_lengths
|
|
|
+ )
|
|
|
+
|
|
|
+ # 3. CTC-Att loss definition
|
|
|
+ if self.ctc_weight == 0.0:
|
|
|
+ loss = loss_att + loss_pre * self.predictor_weight
|
|
|
+ elif self.ctc_weight == 1.0:
|
|
|
+ loss = loss_ctc
|
|
|
+ else:
|
|
|
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
|
|
|
+
|
|
|
+ # Collect Attn branch stats
|
|
|
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
|
|
+ stats["acc"] = acc_att
|
|
|
+ stats["cer"] = cer_att
|
|
|
+ stats["wer"] = wer_att
|
|
|
+ stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
|
|
|
+
|
|
|
+ stats["loss"] = torch.clone(loss.detach())
|
|
|
+
|
|
|
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
|
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
+ return loss, stats, weight
|
|
|
+
|
|
|
+ def encode_chunk(
|
|
|
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
|
|
|
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
|
|
|
+
|
|
|
+ Args:
|
|
|
+ speech: (Batch, Length, ...)
|
|
|
+ speech_lengths: (Batch, )
|
|
|
+ """
|
|
|
+ with autocast(False):
|
|
|
+ # 1. Extract feats
|
|
|
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
|
|
+
|
|
|
+ # 2. Data augmentation
|
|
|
+ if self.specaug is not None and self.training:
|
|
|
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
|
|
|
+
|
|
|
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
|
|
+ if self.normalize is not None:
|
|
|
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
|
|
|
+
|
|
|
+ # Pre-encoder, e.g. used for raw input data
|
|
|
+ if self.preencoder is not None:
|
|
|
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
|
|
+
|
|
|
+ # 4. Forward encoder
|
|
|
+ # feats: (Batch, Length, Dim)
|
|
|
+ # -> encoder_out: (Batch, Length2, Dim2)
|
|
|
+ if self.encoder.interctc_use_conditioning:
|
|
|
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
|
|
|
+ feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
|
|
|
+ intermediate_outs = None
|
|
|
+ if isinstance(encoder_out, tuple):
|
|
|
+ intermediate_outs = encoder_out[1]
|
|
|
+ encoder_out = encoder_out[0]
|
|
|
+
|
|
|
+ # Post-encoder, e.g. NLU
|
|
|
+ if self.postencoder is not None:
|
|
|
+ encoder_out, encoder_out_lens = self.postencoder(
|
|
|
+ encoder_out, encoder_out_lens
|
|
|
+ )
|
|
|
+
|
|
|
+ if intermediate_outs is not None:
|
|
|
+ return (encoder_out, intermediate_outs), encoder_out_lens
|
|
|
+
|
|
|
+ return encoder_out, torch.tensor([encoder_out.size(1)])
|
|
|
+
|
|
|
+ def calc_predictor_chunk(self, encoder_out, cache=None):
|
|
|
+
|
|
|
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \
|
|
|
+ self.predictor.forward_chunk(encoder_out, cache["encoder"])
|
|
|
+ return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
|
|
|
+
|
|
|
+ def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
|
|
|
+ decoder_outs = self.decoder.forward_chunk(
|
|
|
+ encoder_out, sematic_embeds, cache["decoder"]
|
|
|
+ )
|
|
|
+ decoder_out = decoder_outs
|
|
|
+ decoder_out = torch.log_softmax(decoder_out, dim=-1)
|
|
|
+ return decoder_out
|
|
|
+
|
|
|
+
|
|
|
class ParaformerBert(Paraformer):
|
|
|
"""
|
|
|
Author: Speech Lab of DAMO Academy, Alibaba Group
|