| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- # Copyright 2019 Shigeki Karita
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Transformer encoder definition."""
- from typing import List
- from typing import Optional
- from typing import Tuple
- import torch
- from torch import nn
- import logging
- from funasr.models.transformer.attention import MultiHeadedAttention
- from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
- from funasr.models.transformer.embedding import PositionalEncoding
- from funasr.models.transformer.layer_norm import LayerNorm
- from funasr.models.transformer.utils.nets_utils import make_pad_mask
- from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
- from funasr.models.transformer.utils.repeat import repeat
- from funasr.register import tables
class EncoderLayer(nn.Module):
    """Encoder layer module: self-attention block followed by a feed-forward block.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument. Called as ``self_attn(query, key, value, mask)``.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate applied to sub-block outputs before the
            residual add (not applied on the ``concat_after`` attention path).
        normalize_before (bool): If True, apply LayerNorm before each sub-block
            (pre-norm); otherwise apply it after each residual add (post-norm).
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): 3-D mask for the input, e.g. (#batch, 1, time),
                or None.
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
                When given, only the last frame of ``x`` is used as the query and
                the result is appended to ``cache``.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor; with a cache, only its last row along dim 1.
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        # --- self-attention sub-block ---
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            # Incremental mode: query only the newest frame against all frames.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            # Concatenate the *query* frames with their attention output. In cached
            # mode x_q has a single frame, so using the full x would mismatch the
            # attention output along the time axis.
            x_concat = torch.cat((x_q, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)

        # --- feed-forward sub-block ---
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask
@tables.register("encoder_classes", "TransformerTextEncoder")
class TransformerTextEncoder(nn.Module):
    """Transformer encoder over padded token-ID sequences.

    Tokens are embedded with ``torch.nn.Embedding`` followed by a positional
    encoding, then passed through ``num_blocks`` :class:`EncoderLayer` blocks.

    Args:
        input_size: vocabulary size (number of rows of the token embedding table)
        output_size: dimension of attention (model dimension)
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        positional_dropout_rate: dropout rate after adding positional encoding
        attention_dropout_rate: dropout rate in attention
        pos_enc_class: positional encoding class (e.g. PositionalEncoding or
            ScaledPositionalEncoding), called as
            ``pos_enc_class(output_size, positional_dropout_rate)``
        normalize_before: whether to use layer_norm before each block; when True a
            final LayerNorm is also applied to the encoder output
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        super().__init__()
        self._output_size = output_size
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(input_size, output_size),
            pos_enc_class(output_size, positional_dropout_rate),
        )
        self.normalize_before = normalize_before
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(
                    attention_heads, output_size, attention_dropout_rate
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        """Return the model (attention) dimension of the encoder output."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a padded batch of token-ID sequences.

        Args:
            xs_pad: padded token IDs (B, L); fed to ``torch.nn.Embedding``, so it
                must be an integer tensor
            ilens: valid length of each sequence (B)

        Returns:
            encoded features (B, L, output_size), output lengths (B), and None
            (the third element is always None here)
        """
        # Size the mask by the actual padded length, not max(ilens), so it stays
        # aligned with xs_pad even when the batch is padded beyond the longest item.
        masks = (
            ~make_pad_mask(ilens, maxlen=xs_pad.size(1))[:, None, :]
        ).to(xs_pad.device)
        xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        # masks is (B, 1, L); counting valid positions gives per-sequence lengths.
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
@tables.register("encoder_classes", "FusionSANEncoder")
class SelfSrcAttention(nn.Module):
    """Self-attention, source-attention and feed-forward layer (decoder-layer style).

    Registered under the name ``"FusionSANEncoder"``. Unlike :class:`EncoderLayer`,
    this layer builds its own attention and feed-forward sub-modules from
    hyper-parameters.

    Args:
        size (int): Model dimension used by the LayerNorms and concat linears.
        attention_heads (int): Number of heads for both attention modules.
        attention_dim (int): Dimension of the attention and feed-forward modules.
            Should equal ``size``, since their outputs are added to residuals of
            width ``size``.
        linear_units (int): Hidden units of the position-wise feed-forward module.
        self_attention_dropout_rate (float): Dropout rate inside self-attention.
        src_attention_dropout_rate (float): Dropout rate inside source attention.
        positional_dropout_rate (float): Despite its name, used as the dropout rate
            of the position-wise feed-forward module.
        dropout_rate (float): Dropout applied to sub-block outputs before the
            residual add (not applied on the ``concat_after`` attention paths).
        normalize_before (bool): If True, apply LayerNorm before each sub-block
            (pre-norm); otherwise after each residual add (post-norm).
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        attention_heads,
        attention_dim,
        linear_units,
        self_attention_dropout_rate,
        src_attention_dropout_rate,
        positional_dropout_rate,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an SelfSrcAttention object."""
        super(SelfSrcAttention, self).__init__()
        self.size = size
        self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
        # This attention variant returns (output, attention_weights).
        self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
        self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): 3-D mask for the input (sliced along dim 1),
                e.g. (#batch, maxlen_out, maxlen_out), or None.
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask, passed unchanged to the
                source attention (presumably (#batch, 1, maxlen_in) — confirm).
            cache (torch.Tensor): Cached output of previous steps,
                (#batch, maxlen_out - 1, size), or None.

        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: ``tgt_mask`` as given.
            torch.Tensor: ``memory`` as given.
            torch.Tensor: ``memory_mask`` as given.
        """
        # --- self-attention sub-block ---
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # --- source-attention sub-block ---
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        # src_attn returns (output, weights); unpack on BOTH paths. Passing the raw
        # tuple to torch.cat (as the concat_after path used to) raises TypeError.
        src_out, _ = self.src_attn(x, memory, memory, memory_mask)
        if self.concat_after:
            x_concat = torch.cat((x, src_out), dim=-1)
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(src_out)
        if not self.normalize_before:
            x = self.norm2(x)

        # --- feed-forward sub-block ---
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
@tables.register("encoder_classes", "ConvBiasPredictor")
class ConvPredictor(nn.Module):
    """Predict one scalar per text position from text and ASR encodings.

    Registered as ``"ConvBiasPredictor"``. The pipeline is: cross-attention from the
    text encoding to the ASR encoding (residual), a pre-norm feed-forward block
    (residual), then a depthwise 1-D convolution over time (residual), ReLU, and a
    linear projection to a single value.

    Args:
        size: feature dimension of ``text_enc`` (and the attention model dimension).
        l_order: number of past frames seen by the depthwise convolution.
        r_order: number of future frames seen by the depthwise convolution.
        attention_heads: number of cross-attention heads.
        attention_dropout_rate: dropout rate for the attention and feed-forward modules.
        linear_units: hidden units of the feed-forward module.
    """

    def __init__(
        self,
        size=256,
        l_order=3,
        r_order=3,
        attention_heads=4,
        attention_dropout_rate=0.1,
        linear_units=2048,
    ):
        super().__init__()
        # Submodules are created in a fixed order (this affects random init) and
        # their attribute names define the state_dict keys.
        self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
        self.norm1 = LayerNorm(size)
        self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
        self.norm2 = LayerNorm(size)
        # Pad l_order frames on the left and r_order on the right so the conv
        # output keeps the input length.
        self.pad = nn.ConstantPad1d(padding=(l_order, r_order), value=0)
        kernel = l_order + r_order + 1
        # groups == channels -> one filter per channel (depthwise).
        self.conv1d = nn.Conv1d(size, size, kernel_size=kernel, groups=size)
        self.output_linear = nn.Linear(in_features=size, out_features=1)

    def forward(self, text_enc, asr_enc):
        """Return per-position scalar predictions.

        Args:
            text_enc: text encoding, (batch, time, size) — rank and last dim are
                implied by the transpose and the Conv1d channel count.
            asr_enc: ASR encoder output used as attention key/value (presumably
                (batch, asr_time, size) — confirm against callers).

        Returns:
            Tensor of shape (batch, time) when the projection output is 3-D.
        """
        # Stage 1: cross-attention, residual, no pre-norm.
        text_enc = text_enc + self.atten(text_enc, asr_enc, asr_enc, None)

        # Stage 2: pre-norm feed-forward with residual.
        text_enc = text_enc + self.feed_forward(self.norm1(text_enc))

        # Stage 3: depthwise conv over the time axis (channels-first), residual,
        # ReLU, then projection to one value per frame.
        channels_first = self.norm2(text_enc).transpose(1, 2)
        conv_out = self.conv1d(self.pad(channels_first))
        activated = torch.relu((conv_out + channels_first).transpose(1, 2))
        scores = self.output_linear(activated)
        return scores.squeeze(2) if scores.dim() == 3 else scores
|