@@ -63,6 +63,60 @@ class MultiLayeredConv1d(torch.nn.Module):
         return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)


+class FsmnFeedForward(torch.nn.Module):
+    """Position-wise feed-forward module for FSMN blocks.
+
+    A multi-layered conv1d module designed to replace
+    the position-wise feed-forward network
+    in an FSMN block.
+    """
+
+    def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
+        """Initialize FsmnFeedForward module.
+
+        Args:
+            in_chans (int): Number of input channels.
+            hidden_chans (int): Number of hidden channels.
+            out_chans (int): Number of output channels.
+            kernel_size (int): Kernel size of conv1d.
+            dropout_rate (float): Dropout rate.
+
+        """
+        super(FsmnFeedForward, self).__init__()
+        self.w_1 = torch.nn.Conv1d(
+            in_chans,
+            hidden_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+        )
+        self.w_2 = torch.nn.Conv1d(
+            hidden_chans,
+            out_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            bias=False,
+        )
+        self.norm = torch.nn.LayerNorm(hidden_chans)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(self, x, ilens=None):
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+            ilens (torch.Tensor): Batch of input lengths, passed through unchanged.
+
+        Returns:
+            torch.Tensor: Batch of output tensors (B, T, out_chans).
+            torch.Tensor: The input lengths ``ilens``, unchanged.
+
+        """
+        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+        return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
+
+
 class Conv1dLinear(torch.nn.Module):
     """Conv1D + Linear for Transformer block.

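
For reference, a minimal usage sketch of the new FsmnFeedForward module (illustrative only, not part of the patch; the channel sizes, batch/length dimensions, and dropout rate below are arbitrary):

    import torch

    # Arbitrary example sizes; any (B, T, in_chans) input works.
    ff = FsmnFeedForward(in_chans=256, hidden_chans=1024, out_chans=256,
                         kernel_size=1, dropout_rate=0.1)
    x = torch.randn(4, 100, 256)  # (B, T, in_chans)
    y, ilens = ff(x)  # y: (4, 100, 256) = (B, T, out_chans); ilens is returned unchanged (None here)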