嘉渊 2 years ago
parent
commit
f701679677
1 changed file with 6 additions and 4 deletions
  1. 6 4
      funasr/models/encoder/sanm_encoder.py

+ 6 - 4
funasr/models/encoder/sanm_encoder.py

@@ -25,9 +25,11 @@ from funasr.modules.subsampling import Conv2dSubsampling6
 from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
-from funasr.models.ctc import CTC
 from funasr.modules.mask import subsequent_mask, vad_mask
 
+from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
+
 class EncoderLayerSANM(nn.Module):
     def __init__(
         self,
@@ -114,7 +116,7 @@ class EncoderLayerSANM(nn.Module):
 
         return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
 
-class SANMEncoder(torch.nn.Module):
+class SANMEncoder(AbsEncoder):
     """
     author: Speech Lab, Alibaba Group, China
     San-m: Memory equipped self-attention for end-to-end speech recognition
@@ -546,7 +548,7 @@ class SANMEncoder(torch.nn.Module):
         return var_dict_torch_update
 
 
-class SANMEncoderChunkOpt(torch.nn.Module):
+class SANMEncoderChunkOpt(AbsEncoder):
     """
     author: Speech Lab, Alibaba Group, China
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
@@ -959,7 +961,7 @@ class SANMEncoderChunkOpt(torch.nn.Module):
         return var_dict_torch_update
 
 
-class SANMVadEncoder(torch.nn.Module):
+class SANMVadEncoder(AbsEncoder):
     """
     author: Speech Lab, Alibaba Group, China