mossformer_decoder.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import torch
  2. import torch.nn as nn
  3. class MossFormerDecoder(nn.ConvTranspose1d):
  4. """A decoder layer that consists of ConvTranspose1d.
  5. Arguments
  6. ---------
  7. kernel_size : int
  8. Length of filters.
  9. in_channels : int
  10. Number of input channels.
  11. out_channels : int
  12. Number of output channels.
  13. Example
  14. ---------
  15. >>> x = torch.randn(2, 100, 1000)
  16. >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
  17. >>> h = decoder(x)
  18. >>> h.shape
  19. torch.Size([2, 1003])
  20. """
  21. def __init__(self, *args, **kwargs):
  22. super(MossFormerDecoder, self).__init__(*args, **kwargs)
  23. def forward(self, x):
  24. """Return the decoded output.
  25. Arguments
  26. ---------
  27. x : torch.Tensor
  28. Input tensor with dimensionality [B, N, L].
  29. where, B = Batchsize,
  30. N = number of filters
  31. L = time points
  32. """
  33. if x.dim() not in [2, 3]:
  34. raise RuntimeError(
  35. "{} accept 3/4D tensor as input".format(self.__name__)
  36. )
  37. x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
  38. if torch.squeeze(x).dim() == 1:
  39. x = torch.squeeze(x, dim=1)
  40. else:
  41. x = torch.squeeze(x)
  42. return x