|
|
@@ -82,7 +82,8 @@ class FSMNBlock(nn.Module):
|
|
|
def forward(self, input: torch.Tensor, cache: torch.Tensor):
|
|
|
x = torch.unsqueeze(input, 1)
|
|
|
x_per = x.permute(0, 3, 2, 1) # B D T C
|
|
|
-
|
|
|
+
|
|
|
+ cache = cache.to(x_per.device)
|
|
|
y_left = torch.cat((cache, x_per), dim=2)
|
|
|
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
|
|
|
y_left = self.conv_left(y_left)
|
|
|
@@ -297,4 +298,4 @@ if __name__ == '__main__':
|
|
|
print('input shape: {}'.format(x.shape))
|
|
|
print('output shape: {}'.format(y.shape))
|
|
|
|
|
|
- print(fsmn.to_kaldi_net())
|
|
|
+ print(fsmn.to_kaldi_net())
|