|
|
@@ -427,6 +427,7 @@ class StreamingConvInput(torch.nn.Module):
|
|
|
conv_size: Union[int, Tuple],
|
|
|
subsampling_factor: int = 4,
|
|
|
vgg_like: bool = True,
|
|
|
+ conv_kernel_size: int = 3,
|
|
|
output_size: Optional[int] = None,
|
|
|
) -> None:
|
|
|
"""Construct a ConvInput object."""
|
|
|
@@ -436,14 +437,14 @@ class StreamingConvInput(torch.nn.Module):
|
|
|
conv_size1, conv_size2 = conv_size
|
|
|
|
|
|
self.conv = torch.nn.Sequential(
|
|
|
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
torch.nn.MaxPool2d((1, 2)),
|
|
|
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
torch.nn.MaxPool2d((1, 2)),
|
|
|
)
|
|
|
@@ -462,14 +463,14 @@ class StreamingConvInput(torch.nn.Module):
|
|
|
kernel_1 = int(subsampling_factor / 2)
|
|
|
|
|
|
self.conv = torch.nn.Sequential(
|
|
|
- torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
- torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
torch.nn.MaxPool2d((kernel_1, 2)),
|
|
|
- torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
- torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
|
|
|
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
|
|
|
torch.nn.ReLU(),
|
|
|
torch.nn.MaxPool2d((2, 2)),
|
|
|
)
|
|
|
@@ -487,14 +488,14 @@ class StreamingConvInput(torch.nn.Module):
|
|
|
self.conv = torch.nn.Sequential(
|
|
|
torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
|
|
|
torch.nn.ReLU(),
|
|
|
- torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
|
|
|
+ torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
|
|
|
torch.nn.ReLU(),
|
|
|
)
|
|
|
|
|
|
output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
|
|
|
|
|
|
self.subsampling_factor = subsampling_factor
|
|
|
- self.kernel_2 = 3
|
|
|
+ self.kernel_2 = conv_kernel_size
|
|
|
self.stride_2 = 1
|
|
|
|
|
|
self.create_new_mask = self.create_new_conv2d_mask
|