Эх сурвалжийг харах

Merge pull request #367 from alibaba-damo-academy/dev_lhn2

Dev lhn2
hnluo 2 жил өмнө
parent
commit
24f73665e2

+ 1 - 0
funasr/bin/asr_inference_paraformer_streaming.py

@@ -609,6 +609,7 @@ def inference_modelscope(
 
         # 3. Build data-iterator
         is_final = False
+        cache = {}
         if param_dict is not None and "cache" in param_dict:
             cache = param_dict["cache"]
         if param_dict is not None and "is_final" in param_dict:

+ 178 - 63
funasr/models/e2e_asr_paraformer.py

@@ -325,56 +325,6 @@ class Paraformer(AbsESPnetModel):
 
         return encoder_out, encoder_out_lens
 
-    def encode_chunk(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Frontend + Encoder. Note that this method is used by asr_inference.py
-
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-        """
-        with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
-            # 2. Data augmentation
-            if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
-
-            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-            if self.normalize is not None:
-                feats, feats_lengths = self.normalize(feats, feats_lengths)
-
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
-        # 4. Forward encoder
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
-                feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
-
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
-
-        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
-
-        return encoder_out, torch.tensor([encoder_out.size(1)])
-
     def calc_predictor(self, encoder_out, encoder_out_lens):
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
@@ -383,11 +333,6 @@ class Paraformer(AbsESPnetModel):
                                                                                   ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
 
-    def calc_predictor_chunk(self, encoder_out, cache=None):
-
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor.forward_chunk(encoder_out, cache["encoder"])
-        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
     def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
 
         decoder_outs = self.decoder(
@@ -397,14 +342,6 @@ class Paraformer(AbsESPnetModel):
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
 
-    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
-        decoder_outs = self.decoder.forward_chunk(
-            encoder_out, sematic_embeds, cache["decoder"]
-        )
-        decoder_out = decoder_outs
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        return decoder_out
-
     def _extract_feats(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -610,6 +547,184 @@ class Paraformer(AbsESPnetModel):
         return loss_ctc, cer_ctc
 
 
+class ParaformerOnline(Paraformer):
+    """
+    Author: Speech Lab, Alibaba Group, China
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+            self, *args, **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        batch_size = speech.shape[0]
+        self.step_cur += 1
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+        speech = speech[:, :speech_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        loss_ctc, cer_ctc = None, None
+        loss_pre = None
+        stats = dict()
+
+        # 1. CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+
+        # Intermediate CTC (optional)
+        loss_interctc = 0.0
+        if self.interctc_weight != 0.0 and intermediate_outs is not None:
+            for layer_idx, intermediate_out in intermediate_outs:
+                # we assume intermediate_out has the same length & padding
+                # as those of encoder_out
+                loss_ic, cer_ic = self._calc_ctc_loss(
+                    intermediate_out, encoder_out_lens, text, text_lengths
+                )
+                loss_interctc = loss_interctc + loss_ic
+
+                # Collect Intermedaite CTC stats
+                stats["loss_interctc_layer{}".format(layer_idx)] = (
+                    loss_ic.detach() if loss_ic is not None else None
+                )
+                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+            loss_interctc = loss_interctc / len(intermediate_outs)
+
+            # calculate whole encoder loss
+            loss_ctc = (
+                               1 - self.interctc_weight
+                       ) * loss_ctc + self.interctc_weight * loss_interctc
+
+        # 2b. Attention decoder branch
+        if self.ctc_weight != 1.0:
+            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight
+        elif self.ctc_weight == 1.0:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode_chunk(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+                feats, feats_lengths, cache=cache["encoder"], ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(feats, feats_lengths, cache=cache["encoder"])
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, torch.tensor([encoder_out.size(1)])
+
+    def calc_predictor_chunk(self, encoder_out, cache=None):
+
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \
+            self.predictor.forward_chunk(encoder_out, cache["encoder"])
+        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+        decoder_outs = self.decoder.forward_chunk(
+            encoder_out, sematic_embeds, cache["decoder"]
+        )
+        decoder_out = decoder_outs
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out
+
+
 class ParaformerBert(Paraformer):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group

+ 4 - 2
funasr/models/encoder/sanm_encoder.py

@@ -11,7 +11,7 @@ from typeguard import check_argument_types
 import numpy as np
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
 from funasr.modules.layer_norm import LayerNorm
 from funasr.modules.multi_layer_conv import Conv1dLinear
 from funasr.modules.multi_layer_conv import MultiLayeredConv1d
@@ -180,6 +180,8 @@ class SANMEncoder(AbsEncoder):
                 self.embed = torch.nn.Linear(input_size, output_size)
         elif input_layer == "pe":
             self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
         else:
             raise ValueError("unknown input_layer: " + input_layer)
         self.normalize_before = normalize_before
@@ -357,7 +359,7 @@ class SANMEncoder(AbsEncoder):
         if self.embed is None:
             xs_pad = xs_pad
         else:
-            xs_pad = self.embed.forward_chunk(xs_pad, cache)
+            xs_pad = self.embed(xs_pad, cache)
 
         encoder_outs = self.encoders0(xs_pad, None, None, None, None)
         xs_pad, masks = encoder_outs[0], encoder_outs[1]

+ 20 - 3
funasr/modules/embedding.py

@@ -407,7 +407,24 @@ class SinusoidalPositionEncoder(torch.nn.Module):
 
         return x + position_encoding
 
-    def forward_chunk(self, x, cache=None):
+class StreamSinusoidalPositionEncoder(torch.nn.Module):
+    '''
+
+    '''
+    def __int__(self, d_model=80, dropout_rate=0.1):
+        pass
+
+    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
+        batch_size = positions.size(0)
+        positions = positions.type(dtype)
+        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
+        inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
+        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
+        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
+        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+        return encoding.type(dtype)
+
+    def forward(self, x, cache=None):
         start_idx = 0
         pad_left = 0
         pad_right = 0
@@ -419,8 +436,8 @@ class SinusoidalPositionEncoder(torch.nn.Module):
         positions = torch.arange(1, timesteps+start_idx+1)[None, :]
         position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
         outputs = x + position_encoding[:, start_idx: start_idx + timesteps]
-        outputs = outputs.transpose(1,2)
+        outputs = outputs.transpose(1, 2)
         outputs = F.pad(outputs, (pad_left, pad_right))
-        outputs = outputs.transpose(1,2)
+        outputs = outputs.transpose(1, 2)
         return outputs
        

+ 2 - 1
funasr/tasks/asr.py

@@ -39,7 +39,7 @@ from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
 from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_asr_mfcca import MFCCA
 from funasr.models.e2e_uni_asr import UniASR
@@ -121,6 +121,7 @@ model_choices = ClassChoices(
         asr=ESPnetASRModel,
         uniasr=UniASR,
         paraformer=Paraformer,
+        paraformer_online=ParaformerOnline,
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,