游雁 3 years ago
Parent
Commit
5b0047bf58

+ 8 - 4
funasr/export/models/decoder/sanm_decoder.py

@@ -4,9 +4,8 @@ import torch
 import torch.nn as nn


-# from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
-
 from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask

 from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
 from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
@@ -20,11 +19,15 @@ from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as Decod
 class ParaformerSANMDecoder(nn.Module):
     def __init__(self, model,
                  max_seq_len=512,
-                 model_name='decoder'):
+                 model_name='decoder',
+                 onnx: bool = True,):
         super().__init__()
         # self.embed = model.embed #Embedding(model.embed, max_seq_len)
         self.model = model
-        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

         for i, d in enumerate(self.model.decoders):
             if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
@@ -51,6 +54,7 @@ class ParaformerSANMDecoder(nn.Module):
         self.output_layer = model.output_layer
         self.after_norm = model.after_norm
         self.model_name = model_name
+        

     def prepare_mask(self, mask):
         mask_3d_btd = mask[:, :, None]

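Note on the change above: both branches assign an nn.Module with the same call signature, so call sites such as prepare_mask stay untouched; only mask construction differs between the ONNX export path (MakePadMask) and the eager path (sequence_mask, added at the bottom of this commit). A minimal sketch of the shared call site, assuming the mask semantics of the sequence_mask definition below (1 = valid frame):

import torch
from funasr.export.utils.torch_function import sequence_mask

# Either branch yields a module invoked the same way with a lengths tensor.
make_pad_mask = sequence_mask(512, flip=False)
lengths = torch.tensor([3, 5])
mask = make_pad_mask(lengths)  # shape (2, 5): rows [1,1,1,0,0] and [1,1,1,1,1]
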
+ 14 - 9
funasr/export/models/e2e_asr_paraformer.py

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 from funasr.export.utils.torch_function import MakePadMask
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.export.utils.torch_function import sequence_mask
 from funasr.models.encoder.sanm_encoder import SANMEncoder
 from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
 from funasr.models.predictor.cif import CifPredictorV2
@@ -29,19 +29,24 @@ class Paraformer(nn.Module):
             **kwargs,
     ):
         super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
         if isinstance(model.encoder, SANMEncoder):
-            self.encoder = SANMEncoder_export(model.encoder)
+            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
         if isinstance(model.predictor, CifPredictorV2):
             self.predictor = CifPredictorV2_export(model.predictor)
         if isinstance(model.decoder, ParaformerSANMDecoder):
-            self.decoder = ParaformerSANMDecoder_export(model.decoder)
-        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
+        
         self.feats_dim = feats_dim
         self.model_name = model_name
-        self.onnx = False
-        if "onnx" in kwargs:
-            self.onnx = kwargs["onnx"]
-    
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+        
     def forward(
             self,
             speech: torch.Tensor,
@@ -66,7 +71,7 @@ class Paraformer(nn.Module):

     def get_dummy_inputs(self):
         speech = torch.randn(2, 30, self.feats_dim)
-        speech_lengths = torch.tensor([6, 30]).long()
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
         return (speech, speech_lengths)

     def get_input_names(self):

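With this change the onnx flag enters once through Paraformer's kwargs and is threaded into the encoder and decoder wrappers, which pick their mask module accordingly; Paraformer no longer stores self.onnx itself. A hedged usage sketch, assuming the wrapped model is passed as the first constructor argument as the isinstance checks above suggest (how asr_model is loaded, the output path, and the opset choice are illustrative assumptions; only the constructor keywords and the two helper methods come from this diff):

import torch

# asr_model stands in for a trained FunASR model exposing .encoder/.decoder/.predictor.
export_model = Paraformer(asr_model, max_seq_len=512, feats_dim=560, onnx=True)

speech, speech_lengths = export_model.get_dummy_inputs()
torch.onnx.export(
    export_model,
    (speech, speech_lengths),
    'paraformer.onnx',  # output path is an assumption
    input_names=export_model.get_input_names(),
    opset_version=14,   # opset choice is an assumption
)
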
+ 7 - 1
funasr/export/models/encoder/sanm_encoder.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn

 from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
 from funasr.modules.attention import MultiHeadedAttentionSANM
 from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
 from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
@@ -15,13 +16,18 @@ class SANMEncoder(nn.Module):
         max_seq_len=512,
         feats_dim=560,
         model_name='encoder',
+        onnx: bool = True,
     ):
         super().__init__()
         self.embed = model.embed
         self.model = model
-        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
         self.feats_dim = feats_dim

+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
         if hasattr(model, 'encoders0'):
             for i, d in enumerate(self.model.encoders0):
                 if isinstance(d.self_attn, MultiHeadedAttentionSANM):

+ 1 - 1
funasr/export/test_onnx.py

@@ -9,7 +9,7 @@ if __name__ == '__main__':
     output_name = [nd.name for nd in sess.get_outputs()]

     def _get_feed_dict(feats_length):
-        return {'speech': np.zeros((1, feats_length, 560), dtype=np.float32), 'speech_lengths': [feats_length,]}
+        return {'speech': np.zeros((1, feats_length, 560), dtype=np.float32), 'speech_lengths': np.array([feats_length,], dtype=np.int64)}

     def _run(feed_dict):
         output = sess.run(output_name, input_feed=feed_dict)

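One detail worth flagging on this test fix: the explicit np.array(..., dtype=np.int64) makes speech_lengths a proper tensor of a definite integer dtype, rather than leaving the conversion of a plain Python list (and its platform-dependent default integer width) to onnxruntime. A standalone sketch of the resulting call, mirroring test_onnx.py (the model path is an assumption):

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession('model.onnx')  # path is an assumption
output_name = [nd.name for nd in sess.get_outputs()]

feats_length = 100
feed = {
    'speech': np.zeros((1, feats_length, 560), dtype=np.float32),
    'speech_lengths': np.array([feats_length], dtype=np.int64),
}
output = sess.run(output_name, input_feed=feed)
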
+ 12 - 0
funasr/export/utils/torch_function.py

@@ -44,6 +44,18 @@ class MakePadMask(nn.Module):
         else:
             return mask

+class sequence_mask(nn.Module):
+    def __init__(self, max_seq_len=512, flip=True):
+        super().__init__()
+    
+    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
+        if max_seq_len is None:
+            max_seq_len = lengths.max()
+        row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
+        matrix = torch.unsqueeze(lengths, dim=-1)
+        mask = row_vector < matrix
+        
+        return mask.type(dtype).to(device) if device is not None else mask.type(dtype)

 def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
     if out is None:
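
For reference, the new sequence_mask module resolves everything at call time: its __init__ accepts max_seq_len and flip but uses neither, so instances constructed as sequence_mask(max_seq_len, flip=False) elsewhere in this commit behave identically to sequence_mask(). A quick behavioral check, grounded in the forward method above:

import torch

mask_fn = sequence_mask()
lengths = torch.tensor([2, 4])

mask_fn(lengths)
# max_seq_len defaults to lengths.max() == 4:
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 1.]])

mask_fn(lengths, max_seq_len=6)     # pad out to an explicit width: shape (2, 6)
mask_fn(lengths, dtype=torch.bool)  # boolean mask instead of the float32 default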