Ver Fonte

add sond model

志浩 há 3 anos atrás
pai
commit
7fe447185c
1 ficheiros alterados com 52 adições e 0 exclusões
  1. 52 0
      funasr/modules/multi_layer_conv.py

+ 52 - 0
funasr/modules/multi_layer_conv.py

@@ -63,6 +63,58 @@ class MultiLayeredConv1d(torch.nn.Module):
         return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
 
 
+class FsmnFeedForward(torch.nn.Module):
+    """Position-wise feed forward for FSMN blocks.
+
+    This is a module of multi-leyered conv1d designed
+    to replace position-wise feed-forward network
+    in FSMN block.
+    """
+
+    def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
+        """Initialize FsmnFeedForward module.
+
+        Args:
+            in_chans (int): Number of input channels.
+            hidden_chans (int): Number of hidden channels.
+            out_chans (int): Number of output channels.
+            kernel_size (int): Kernel size of conv1d.
+            dropout_rate (float): Dropout rate.
+
+        """
+        super(FsmnFeedForward, self).__init__()
+        self.w_1 = torch.nn.Conv1d(
+            in_chans,
+            hidden_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+        )
+        self.w_2 = torch.nn.Conv1d(
+            hidden_chans,
+            out_chans,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            bias=False
+        )
+        self.norm = torch.nn.LayerNorm(hidden_chans)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(self, x, ilens=None):
+        """Calculate forward propagation.
+
+        Args:
+            x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+
+        Returns:
+            torch.Tensor: Batch of output tensors (B, T, out_chans).
+
+        """
+        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+        return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
+
+
 class Conv1dLinear(torch.nn.Module):
     """Conv1D + Linear for Transformer block.