Jelajahi Sumber

fix decoder cache

haoneng.lhn 3 tahun lalu
induk
melakukan
6be782d9fd
1 mengubah file dengan 7 tambahan dan 7 penghapusan
  1 changed file: 7 additions, 7 deletions
      funasr/models/decoder/sanm_decoder.py

+ 7 - 7
funasr/models/decoder/sanm_decoder.py

@@ -94,7 +94,7 @@ class DecoderLayerSANM(nn.Module):
         if self.self_attn:
             if self.normalize_before:
                 tgt = self.norm2(tgt)
-            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x, _ = self.self_attn(tgt, tgt_mask)
             x = residual + self.dropout(x)
 
         if self.src_attn is not None:
@@ -399,7 +399,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, memory_mask, cache=c
             )
             new_cache.append(c_ret)
@@ -409,13 +409,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, memory_mask, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )
 
@@ -1076,7 +1076,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=c
            )
             new_cache.append(c_ret)
@@ -1086,14 +1086,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
                 c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder(
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                     x, tgt_mask, memory, None, cache=c
                 )
                 new_cache.append(c_ret)
 
         for decoder in self.decoders3:
 
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                 x, tgt_mask, memory, None, cache=None
             )