| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436 |
- from typing import Optional
- from typing import Tuple
- import logging
- import torch
- from torch import nn
- from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
- from funasr.models.transformer.utils.nets_utils import get_activation
- from funasr.models.transformer.utils.nets_utils import make_pad_mask
- from funasr.models.transformer.attention import (
- MultiHeadedAttention, # noqa: H301
- RelPositionMultiHeadedAttention, # noqa: H301
- LegacyRelPositionMultiHeadedAttention, # noqa: H301
- )
- from funasr.models.transformer.embedding import (
- PositionalEncoding, # noqa: H301
- ScaledPositionalEncoding, # noqa: H301
- RelPositionalEncoding, # noqa: H301
- LegacyRelPositionalEncoding, # noqa: H301
- )
- from funasr.models.transformer.layer_norm import LayerNorm
- from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
- from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
- from funasr.models.transformer.positionwise_feed_forward import (
- PositionwiseFeedForward, # noqa: H301
- )
- from funasr.models.transformer.utils.repeat import repeat
- from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
- from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
- from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
- from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
- from funasr.models.transformer.utils.subsampling import TooShortUttError
- from funasr.models.transformer.utils.subsampling import check_short_utt
- from funasr.models.encoder.abs_encoder import AbsEncoder
- import pdb
- import math
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Applies, along the time axis: pointwise conv (to 2x channels) -> GLU ->
    depthwise conv -> BatchNorm -> activation -> pointwise conv.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of the depthwise conv layer. Must be odd
            so that ``(kernel_size - 1) // 2`` padding preserves the time length.
        activation (torch.nn.Module, optional): Activation applied after the
            batch norm. When omitted (``None``), a fresh ``nn.ReLU()`` is created
            for this instance.
        bias (bool): Whether the conv layers use a bias term.
    """

    def __init__(self, channels, kernel_size, activation=None, bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size must be an odd number for 'SAME' padding.
        assert (kernel_size - 1) % 2 == 0
        if activation is None:
            # Create the default per instance; a module instance used as a
            # default argument would be shared by every ConvolutionModule.
            activation = nn.ReLU()
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # Conv1d works on (batch, channels, time): swap time and feature axes.
        x = x.transpose(1, 2)
        # GLU mechanism: 2*channels -> channels.
        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
        # 1D depthwise conv (one filter per channel via groups=channels).
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))
        x = self.pointwise_conv2(x)
        return x.transpose(1, 2)
class MFCCAEncoder(AbsEncoder):
    """Conformer-style encoder with multi-channel fusion (MFCCA).

    The input passes through an input layer (``self.embed``) and a stack of
    ``EncoderLayer`` blocks that also receive ``channel_size``. In
    :meth:`forward`, the stack output is reshaped to
    ``(-1, channel_size, time, dim)`` and fused across channels by four
    ``Conv2d`` layers (8 -> 16 -> 32 -> 16 -> 1 channels).

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module, None]): Input layer type:
            "linear", "conv2d", "conv2d6", "conv2d8", "embed", a module, or None.
        normalize_before (bool): Whether to use layer_norm before the first block.
            When True, ``after_norm`` is also applied to the final output.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        rel_pos_type (str): "legacy" or "latest". With "legacy", "rel_pos" and
            "rel_selfattn" are mapped to their legacy implementations. The legacy
            relative positional encoding will be deprecated in the future. More
            details can be found in https://github.com/espnet/espnet/pull/2816.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention
            matrix (only used by "rel_selfattn").
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
    ):
        super().__init__()
        self._output_size = output_size

        # Resolve the relative-position flavour before picking classes.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError(f"unknown rel_pos_type: {rel_pos_type}")

        activation = get_activation(activation_type)

        # Positional encoding class, used by every input-layer variant below.
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError(f"unknown pos_enc_layer: {pos_enc_layer_type}")

        # Input layer.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            # f-string so a non-str value still yields the intended ValueError.
            raise ValueError(f"unknown input_layer: {input_layer}")
        self.normalize_before = normalize_before

        # Position-wise feed-forward layer.
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError(
                "Support only linear, conv1d, or conv1d-linear."
            )

        # Self-attention layer (configured flavour).
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError(f"unknown encoder_attn_layer: {selfattention_layer_type}")

        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)

        # Each EncoderLayer also receives a plain MultiHeadedAttention in
        # addition to the configured one. NOTE(review): its role is defined in
        # encoder_layer_mfcca (presumably cross-channel attention) — confirm.
        encoder_selfattn_layer_raw = MultiHeadedAttention
        encoder_selfattn_layer_args_raw = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        # Channel-fusion stack: 8 input channels -> 1 output channel. The
        # padding (2, 3) for kernel [5, 7] keeps the (time, dim) size.
        self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def _embed_input(
        self, xs_pad: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the padding mask from ``ilens`` and run the input layer.

        Returns:
            The embedded input (possibly a tuple; callers unwrap it with
            ``[0]``) and the mask, subsampled only by Conv2dSubsampling* layers.

        Raises:
            TooShortUttError: If a Conv2dSubsampling* input layer cannot
                subsample ``xs_pad.size(1)`` frames.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if isinstance(
            self.embed, (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            # Sequential input layers take no mask; it is passed on unchanged.
            xs_pad = self.embed(xs_pad)
        return xs_pad, masks

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        channel_size: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a multi-channel batch and fuse the channels.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). The
                batch axis is later reshaped to ``(-1, channel_size, ...)``, so
                it appears to stack ``channel_size`` channels per utterance.
            ilens (torch.Tensor): Input length (#batch).
            channel_size: Number of channels per utterance. It is passed
                through the encoder stack, and the returned value is used as a
                number in ``reshape`` / ``math.ceil`` below.
            prev_states (torch.Tensor): Not used.

        Returns:
            torch.Tensor: Fused output (#batch / channel_size, L', output_size).
            torch.Tensor: Output lengths, taken from the first channel's mask.
            None: Placeholder for states.

        Raises:
            TooShortUttError: If the input is too short for the subsampling layer.
        """
        xs_pad, masks = self._embed_input(xs_pad, ilens)
        xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        t_leng = xs_pad.size(1)
        d_dim = xs_pad.size(2)
        xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
        # conv1 expects exactly 8 input channels: tile fewer channels up to 8.
        # NOTE(review): more than 8 channels would not match conv1.
        if channel_size < 8:
            repeat_num = math.ceil(8 / channel_size)
            xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
        xs_pad = self.conv1(xs_pad)
        xs_pad = self.conv2(xs_pad)
        xs_pad = self.conv3(xs_pad)
        xs_pad = self.conv4(xs_pad)
        # conv4 leaves a single channel; drop it to get (groups, time, dim).
        xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
        # Keep one mask per channel group (the first channel's).
        mask_tmp = masks.size(1)
        masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None

    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode and keep the output of the middle encoder layer.

        The output of layer ``max(num_layer // 2 - 1, 0)`` is stored in
        ``self.hidden_feature``. It is normalised with ``after_norm`` when
        ``normalize_before`` is True. No channel fusion is applied here.

        NOTE(review): ``forward`` calls the encoder stack as
        ``encoders(xs_pad, masks, channel_size)`` and unpacks three values.
        This method calls each layer with two arguments and unpacks two.
        Confirm against ``EncoderLayer`` (encoder_layer_mfcca) that this path
        is still usable.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not used.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            None: Placeholder for states.

        Raises:
            TooShortUttError: If the input is too short for the subsampling layer.
        """
        xs_pad, masks = self._embed_input(xs_pad, ilens)
        num_layer = len(self.encoders)
        # Clamp so encoders with fewer than two layers still record a feature;
        # num_layer // 2 - 1 would be -1 and never match an index.
        hidden_idx = max(num_layer // 2 - 1, 0)
        hidden_feature = xs_pad
        for idx, encoder in enumerate(self.encoders):
            xs_pad, masks = encoder(xs_pad, masks)
            if idx == hidden_idx:
                hidden_feature = xs_pad
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
            hidden_feature = hidden_feature[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
            self.hidden_feature = self.after_norm(hidden_feature)
        else:
            # Previously never set on this path, leaving the attribute missing.
            self.hidden_feature = hidden_feature
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
|