| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- import torch
- import torch.nn as nn
class MossFormerDecoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    All constructor arguments are forwarded unchanged to
    ``torch.nn.ConvTranspose1d``.

    Arguments
    ---------
    kernel_size : int
        Length of filters.
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.

    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = MossFormerDecoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def forward(self, x: torch.Tensor, output_size=None) -> torch.Tensor:
        """Return the decoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with dimensionality [B, N, L],
            where B = batch size, N = number of filters, L = time points.
            A 2-D input [B, L] is unsqueezed to [B, 1, L] (a single input
            channel), so in that case ``in_channels`` must be 1.
        output_size : list of int, optional
            Passed through to ``ConvTranspose1d.forward`` to select the exact
            output length when the stride makes it ambiguous.

        Returns
        -------
        torch.Tensor
            The transposed-convolution output with every size-1 dimension
            squeezed. If squeezing everything would leave a 1-D tensor, only
            the channel dimension (dim 1) is removed instead. For example,
            [1, 1, T] becomes [1, T] rather than [T].
            NOTE(review): with batch size 1 and out_channels > 1 the batch
            dimension is still dropped ([1, C, T] -> [C, T]).

        Raises
        ------
        RuntimeError
            If ``x`` is not a 2-D or 3-D tensor.
        """
        if x.dim() not in (2, 3):
            # type(self).__name__: nn.Module instances have no ``__name__``
            # attribute, so ``self.__name__`` would raise AttributeError here.
            raise RuntimeError(
                "{} accepts 2/3D tensor as input".format(type(self).__name__)
            )
        if x.dim() == 2:
            # Treat a 2-D input as [B, L] with a single channel.
            x = torch.unsqueeze(x, 1)
        x = super().forward(x, output_size=output_size)

        squeezed = torch.squeeze(x)
        if squeezed.dim() == 1:
            # Keep the batch dimension when only one non-singleton dim remains.
            return torch.squeeze(x, dim=1)
        return squeezed
|