Просмотр исходного кода

add paraformer online infer and finetune

haoneng.lhn 2 лет назад
Родитель
Сommit
84b4a01979

+ 3 - 1
funasr/bin/asr_inference_launch.py

@@ -1618,6 +1618,8 @@ def inference_launch(**kwargs):
         return inference_uniasr(**kwargs)
         return inference_uniasr(**kwargs)
     elif mode == "paraformer":
     elif mode == "paraformer":
         return inference_paraformer(**kwargs)
         return inference_paraformer(**kwargs)
+    elif mode == "paraformer_online":
+        return inference_paraformer(**kwargs)
     elif mode == "paraformer_streaming":
     elif mode == "paraformer_streaming":
         return inference_paraformer_online(**kwargs)
         return inference_paraformer_online(**kwargs)
     elif mode.startswith("paraformer_vad"):
     elif mode.startswith("paraformer_vad"):
@@ -1900,4 +1902,4 @@ def main(cmd=None):
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    main()
+    main()

+ 3 - 3
funasr/models/decoder/sanm_decoder.py

@@ -956,14 +956,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         """
         """
         tgt = ys_in_pad
         tgt = ys_in_pad
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+        
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
         if chunk_mask is not None:
         if chunk_mask is not None:
             memory_mask = memory_mask * chunk_mask
             memory_mask = memory_mask * chunk_mask
             if tgt_mask.size(1) != memory_mask.size(1):
             if tgt_mask.size(1) != memory_mask.size(1):
                 memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
                 memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
 
 
-        memory = hs_pad
-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
         x = tgt
         x = tgt
         x, tgt_mask, memory, memory_mask, _ = self.decoders(
         x, tgt_mask, memory, memory_mask, _ = self.decoders(
             x, tgt_mask, memory, memory_mask
             x, tgt_mask, memory, memory_mask

+ 93 - 3
funasr/models/e2e_asr_paraformer.py

@@ -279,7 +279,7 @@ class Paraformer(FunASRModel):
 
 
     def encode(
     def encode(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
-    ) -> Tuple[Tuple[Any, Optional[Any]], Any]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
         Args:
                 speech: (Batch, Length, ...)
                 speech: (Batch, Length, ...)
@@ -649,7 +649,35 @@ class ParaformerOnline(Paraformer):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
 
-        super().__init__()
+        super().__init__(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
+        )
         # note that eos is the same as sos (equivalent ID)
         # note that eos is the same as sos (equivalent ID)
         self.blank_id = blank_id
         self.blank_id = blank_id
         self.sos = vocab_size - 1 if sos is None else sos
         self.sos = vocab_size - 1 if sos is None else sos
@@ -705,6 +733,7 @@ class ParaformerOnline(Paraformer):
         self.sampling_ratio = sampling_ratio
         self.sampling_ratio = sampling_ratio
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
         self.step_cur = 0
         self.step_cur = 0
+        self.scama_mask = None
         if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
         if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
             from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
             from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
             self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
             self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -859,7 +888,7 @@ class ParaformerOnline(Paraformer):
         # Pre-encoder, e.g. used for raw input data
         # Pre-encoder, e.g. used for raw input data
         if self.preencoder is not None:
         if self.preencoder is not None:
             feats, feats_lengths = self.preencoder(feats, feats_lengths)
             feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
+        
         # 4. Forward encoder
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
         # -> encoder_out: (Batch, Length2, Dim2)
@@ -1111,12 +1140,73 @@ class ParaformerOnline(Paraformer):
 
 
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
         return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
 
 
+    def calc_predictor(self, encoder_out, encoder_out_lens):
+
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                           device=encoder_out.device,
+                                                                                           batch_size=encoder_out.size(
+                                                                                               0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                                   batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+                                                                                           None,
+                                                                                           encoder_out_mask,
+                                                                                           ignore_id=self.ignore_id,
+                                                                                           mask_chunk_predictor=mask_chunk_predictor,
+                                                                                           target_label_length=None,
+                                                                                           )
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas[:, :-1],
+                                                                                             encoder_out_lens)
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
+                get_mask_shift_att_chunk_decoder(None,
+                                                 device=encoder_out.device,
+                                                 batch_size=encoder_out.size(0)
+                                                 )
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=None,
+                is_training=self.training,
+            )
+        self.scama_mask = scama_mask
+
+        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
     def calc_predictor_chunk(self, encoder_out, cache=None):
     def calc_predictor_chunk(self, encoder_out, cache=None):
 
 
         pre_acoustic_embeds, pre_token_length = \
         pre_acoustic_embeds, pre_token_length = \
             self.predictor.forward_chunk(encoder_out, cache["encoder"])
             self.predictor.forward_chunk(encoder_out, cache["encoder"])
         return pre_acoustic_embeds, pre_token_length
         return pre_acoustic_embeds, pre_token_length
 
 
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+        )
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
+
     def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
     def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
         decoder_outs = self.decoder.forward_chunk(
         decoder_outs = self.decoder.forward_chunk(
             encoder_out, sematic_embeds, cache["decoder"]
             encoder_out, sematic_embeds, cache["decoder"]