|
|
@@ -506,9 +506,9 @@ class StreamingConvInput(torch.nn.Module):
|
|
|
)
|
|
|
|
|
|
self.conv = torch.nn.Sequential(
|
|
|
- torch.nn.Conv2d(1, conv_size, 3, 2),
|
|
|
+ torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
|
|
|
torch.nn.ReLU(),
|
|
|
- torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
|
|
|
+ torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
|
|
|
torch.nn.ReLU(),
|
|
|
)
|
|
|
|
|
|
@@ -597,7 +597,7 @@ class StreamingConvInput(torch.nn.Module):
|
|
|
mask: Mask of output sequences. (B, sub(T))
|
|
|
"""
|
|
|
if self.subsampling_factor > 1:
|
|
|
- return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
|
|
|
+ return mask[:, ::2][:, ::self.stride_2]
|
|
|
else:
|
|
|
return mask
|
|
|
|