凌匀 3 лет назад
Родитель
Сommit
4c053ccc39
1 измененных файлов с 3 добавлено и 2 удалено
  1. 3 2
      funasr/models/encoder/fsmn_encoder.py

+ 3 - 2
funasr/models/encoder/fsmn_encoder.py

@@ -82,7 +82,8 @@ class FSMNBlock(nn.Module):
     def forward(self, input: torch.Tensor, cache: torch.Tensor):
         x = torch.unsqueeze(input, 1)
         x_per = x.permute(0, 3, 2, 1)  # B D T C
-
+        
+        cache = cache.to(x_per.device)
         y_left = torch.cat((cache, x_per), dim=2)
         cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
         y_left = self.conv_left(y_left)
@@ -297,4 +298,4 @@ if __name__ == '__main__':
     print('input shape: {}'.format(x.shape))
     print('output shape: {}'.format(y.shape))
 
-    print(fsmn.to_kaldi_net())
+    print(fsmn.to_kaldi_net())