  1. # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
  2. # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """ This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
  4. import torch
  5. import torch.nn as nn
  6. class TAP(nn.Module):
  7. """
  8. Temporal average pooling, only first-order mean is considered
  9. """
  10. def __init__(self, **kwargs):
  11. super(TAP, self).__init__()
  12. def forward(self, x):
  13. pooling_mean = x.mean(dim=-1)
  14. # To be compatable with 2D input
  15. pooling_mean = pooling_mean.flatten(start_dim=1)
  16. return pooling_mean
  17. class TSDP(nn.Module):
  18. """
  19. Temporal standard deviation pooling, only second-order std is considered
  20. """
  21. def __init__(self, **kwargs):
  22. super(TSDP, self).__init__()
  23. def forward(self, x):
  24. # The last dimension is the temporal axis
  25. pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
  26. pooling_std = pooling_std.flatten(start_dim=1)
  27. return pooling_std
  28. class TSTP(nn.Module):
  29. """
  30. Temporal statistics pooling, concatenate mean and std, which is used in
  31. x-vector
  32. Comment: simple concatenation can not make full use of both statistics
  33. """
  34. def __init__(self, **kwargs):
  35. super(TSTP, self).__init__()
  36. def forward(self, x):
  37. # The last dimension is the temporal axis
  38. pooling_mean = x.mean(dim=-1)
  39. pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
  40. pooling_mean = pooling_mean.flatten(start_dim=1)
  41. pooling_std = pooling_std.flatten(start_dim=1)
  42. stats = torch.cat((pooling_mean, pooling_std), 1)
  43. return stats
  44. class ASTP(nn.Module):
  45. """ Attentive statistics pooling: Channel- and context-dependent
  46. statistics pooling, first used in ECAPA_TDNN.
  47. """
  48. def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
  49. super(ASTP, self).__init__()
  50. self.global_context_att = global_context_att
  51. # Use Conv1d with stride == 1 rather than Linear, then we don't
  52. # need to transpose inputs.
  53. if global_context_att:
  54. self.linear1 = nn.Conv1d(
  55. in_dim * 3, bottleneck_dim,
  56. kernel_size=1) # equals W and b in the paper
  57. else:
  58. self.linear1 = nn.Conv1d(
  59. in_dim, bottleneck_dim,
  60. kernel_size=1) # equals W and b in the paper
  61. self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
  62. kernel_size=1) # equals V and k in the paper
  63. def forward(self, x):
  64. """
  65. x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
  66. or a 4-dimensional tensor in resnet architecture (B,C,F,T)
  67. 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
  68. """
  69. if len(x.shape) == 4:
  70. x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
  71. assert len(x.shape) == 3
  72. if self.global_context_att:
  73. context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
  74. context_std = torch.sqrt(
  75. torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
  76. x_in = torch.cat((x, context_mean, context_std), dim=1)
  77. else:
  78. x_in = x
  79. # DON'T use ReLU here! ReLU may be hard to converge.
  80. alpha = torch.tanh(
  81. self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
  82. alpha = torch.softmax(self.linear2(alpha), dim=2)
  83. mean = torch.sum(alpha * x, dim=2)
  84. var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
  85. std = torch.sqrt(var.clamp(min=1e-10))
  86. return torch.cat([mean, std], dim=1)