游雁 3 سال پیش
والد
کامیت
795b6e0486
2فایلهای تغییر یافته به همراه8 افزوده شده و 9 حذف شده
  1. 1 6
      funasr/export/models/encoder/sanm_encoder.py
  2. 7 3
      funasr/export/models/vad_realtime_transformer.py

+ 1 - 6
funasr/export/models/encoder/sanm_encoder.py

@@ -151,12 +151,7 @@ class SANMVadEncoder(nn.Module):
     
     def prepare_mask(self, mask, sub_masks):
         mask_3d_btd = mask[:, :, None]
-        # sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32)
-        if len(mask.shape) == 2:
-            mask_4d_bhlt = 1 - sub_masks[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask_4d_bhlt = 1 - sub_masks[:, None, :]
-        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+        mask_4d_bhlt = (1 - sub_masks) * -10000.0
         
         return mask_3d_btd, mask_4d_bhlt
     

+ 7 - 3
funasr/export/models/vad_realtime_transformer.py

@@ -63,11 +63,11 @@ class VadRealtimeTransformer(nn.Module):
         text_lengths = torch.tensor([length], dtype=torch.int32)
         vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
         sub_masks = torch.ones(length, length, dtype=torch.float32)
-        sub_masks = torch.tril(sub_masks)
-        return (text_indexes, text_lengths, vad_mask, sub_masks)
+        sub_masks = torch.tril(sub_masks).type(torch.float32)
+        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
 
     def get_input_names(self):
-        return ['input', 'text_lengths', 'vad_mask']
+        return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
 
     def get_output_names(self):
         return ['logits']
@@ -81,6 +81,10 @@ class VadRealtimeTransformer(nn.Module):
                 2: 'feats_length1',
                 3: 'feats_length2'
             },
+            'sub_masks': {
+                2: 'feats_length1',
+                3: 'feats_length2'
+            },
             'logits': {
                 1: 'logits_length'
             },