# vgg2l.py
  1. """VGG2L module definition for custom encoder."""
  2. from typing import Tuple, Union
  3. import torch
  4. class VGG2L(torch.nn.Module):
  5. """VGG2L module for custom encoder.
  6. Args:
  7. idim: Input dimension.
  8. odim: Output dimension.
  9. pos_enc: Positional encoding class.
  10. """
  11. def __init__(self, idim: int, odim: int, pos_enc: torch.nn.Module = None):
  12. """Construct a VGG2L object."""
  13. super().__init__()
  14. self.vgg2l = torch.nn.Sequential(
  15. torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
  16. torch.nn.ReLU(),
  17. torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
  18. torch.nn.ReLU(),
  19. torch.nn.MaxPool2d((3, 2)),
  20. torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
  21. torch.nn.ReLU(),
  22. torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
  23. torch.nn.ReLU(),
  24. torch.nn.MaxPool2d((2, 2)),
  25. )
  26. if pos_enc is not None:
  27. self.output = torch.nn.Sequential(
  28. torch.nn.Linear(128 * ((idim // 2) // 2), odim), pos_enc
  29. )
  30. else:
  31. self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
  32. def forward(
  33. self, feats: torch.Tensor, feats_mask: torch.Tensor
  34. ) -> Union[
  35. Tuple[torch.Tensor, torch.Tensor],
  36. Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
  37. ]:
  38. """Forward VGG2L bottleneck.
  39. Args:
  40. feats: Feature sequences. (B, F, D_feats)
  41. feats_mask: Mask of feature sequences. (B, 1, F)
  42. Returns:
  43. vgg_output: VGG output sequences.
  44. (B, sub(F), D_out) or ((B, sub(F), D_out), (B, sub(F), D_att))
  45. vgg_mask: Mask of VGG output sequences. (B, 1, sub(F))
  46. """
  47. feats = feats.unsqueeze(1)
  48. vgg_output = self.vgg2l(feats)
  49. b, c, t, f = vgg_output.size()
  50. vgg_output = self.output(
  51. vgg_output.transpose(1, 2).contiguous().view(b, t, c * f)
  52. )
  53. if feats_mask is not None:
  54. vgg_mask = self.create_new_mask(feats_mask)
  55. else:
  56. vgg_mask = feats_mask
  57. return vgg_output, vgg_mask
  58. def create_new_mask(self, feats_mask: torch.Tensor) -> torch.Tensor:
  59. """Create a subsampled mask of feature sequences.
  60. Args:
  61. feats_mask: Mask of feature sequences. (B, 1, F)
  62. Returns:
  63. vgg_mask: Mask of VGG2L output sequences. (B, 1, sub(F))
  64. """
  65. vgg1_t_len = feats_mask.size(2) - (feats_mask.size(2) % 3)
  66. vgg_mask = feats_mask[:, :, :vgg1_t_len][:, :, ::3]
  67. vgg2_t_len = vgg_mask.size(2) - (vgg_mask.size(2) % 2)
  68. vgg_mask = vgg_mask[:, :, :vgg2_t_len][:, :, ::2]
  69. return vgg_mask