| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
import torch
import torch.nn as nn
class TAP(nn.Module):
    """Temporal average pooling.

    Collapses the trailing time axis with a plain mean; only the
    first-order statistic is used.
    """

    def __init__(self, **kwargs):
        super(TAP, self).__init__()

    def forward(self, x):
        # Average over the last (time) axis, then flatten any remaining
        # feature axes so both 3-D (B, F, T) and 4-D (B, C, F, T) inputs
        # come out as a 2-D (B, D) embedding.
        mean = torch.mean(x, dim=-1)
        return mean.flatten(start_dim=1)
class TSDP(nn.Module):
    """Temporal standard-deviation pooling.

    Only the second-order statistic (std over time) is used.
    """

    def __init__(self, **kwargs):
        super(TSDP, self).__init__()

    def forward(self, x):
        # The last axis is time. The small epsilon keeps sqrt well-behaved
        # (and differentiable) when a channel is constant over time.
        var = torch.var(x, dim=-1)
        std = torch.sqrt(var + 1e-8)
        return std.flatten(start_dim=1)
class TSTP(nn.Module):
    """Temporal statistics pooling: concatenation of mean and std over time.

    This is the pooling used in the x-vector architecture. Note that a
    simple concatenation cannot make full use of both statistics.
    """

    def __init__(self, **kwargs):
        super(TSTP, self).__init__()

    def forward(self, x):
        # Time is the trailing axis. Each statistic is flattened to (B, D)
        # before concatenation so 4-D feature maps are also supported.
        mean = x.mean(dim=-1).flatten(start_dim=1)
        std = torch.sqrt(torch.var(x, dim=-1) + 1e-8).flatten(start_dim=1)
        return torch.cat((mean, std), dim=1)
class ASTP(nn.Module):
    """Attentive statistics pooling, as introduced with ECAPA-TDNN.

    Channel- (and optionally context-) dependent attention weights are
    computed over time; the pooled output is the attention-weighted mean
    concatenated with the attention-weighted std.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super(ASTP, self).__init__()
        self.global_context_att = global_context_att
        # 1x1 Conv1d behaves like a Linear layer applied per frame, but
        # works directly on (B, F, T) tensors without a transpose.
        # With global context, the input is [x; mean; std] -> 3 * in_dim.
        attention_in = in_dim * 3 if global_context_att else in_dim
        self.linear1 = nn.Conv1d(attention_in, bottleneck_dim,
                                 kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
                                 kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        x: a 3-D tensor (B, F, T) from tdnn-based backbones, or a 4-D
        tensor (B, C, F, T) from resnet backbones. Dim 0 is the batch and
        the last dim is time (frames).
        """
        if x.dim() == 4:
            # Fold channel and frequency axes into one feature axis.
            x = x.reshape(x.shape[0], -1, x.shape[-1])
        assert x.dim() == 3

        if self.global_context_att:
            # Broadcast the utterance-level mean/std across every frame.
            ctx_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            ctx_std = torch.sqrt(
                torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            attention_in = torch.cat((x, ctx_mean, ctx_std), dim=1)
        else:
            attention_in = x

        # DON'T use ReLU here! ReLU may be hard to converge.
        hidden = torch.tanh(self.linear1(attention_in))
        alpha = torch.softmax(self.linear2(hidden), dim=2)

        # Weighted first/second moments over time; clamp guards the sqrt
        # against tiny negative values from floating-point cancellation.
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)
|