import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from rotary_embedding_torch import RotaryEmbedding
except ImportError:
    # Only a missing package should trigger the install hint; a bare
    # "except:" would also swallow SystemExit/KeyboardInterrupt.
    # NOTE: execution continues with RotaryEmbedding undefined, so
    # MossformerBlock.__init__ will raise NameError until it is installed.
    print("Please install rotary_embedding_torch by: \n pip install -U rotary_embedding_torch")

from funasr.modules.layer_norm import GlobalLayerNorm, CumulativeLayerNorm, ScaleNorm
from funasr.modules.embedding import ScaledSinuEmbedding
from funasr.modules.mossformer import FLASH_ShareA_FFConvM
def select_norm(norm, dim, shape):
    """Instantiate the normalization layer matching ``norm``.

    Arguments
    ---------
    norm : str
        One of "gln", "cln", "ln"; any other value falls back to BatchNorm1d.
    dim : int
        Number of channels to normalize.
    shape : int
        Trailing-shape argument forwarded to GlobalLayerNorm.
    """
    if norm == "gln":
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    elif norm == "cln":
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    elif norm == "ln":
        # GroupNorm with a single group normalizes across all channels,
        # i.e. a layer norm over the channel dimension.
        return nn.GroupNorm(1, dim, eps=1e-8)
    # Default: plain batch normalization over channels.
    return nn.BatchNorm1d(dim)
class MossformerBlock(nn.Module):
    """Stack of ``depth`` FLASH_ShareA_FFConvM layers sharing one rotary embedding.

    Arguments
    ---------
    dim : int
        Model (embedding) dimension of the input and output.
    depth : int
        Number of stacked FLASH layers.
    group_size : int
        Chunk size used by the FLASH attention.
    query_key_dim : int
        Dimension of the query/key projections.
    expansion_factor : float
        Expansion factor of the feed-forward conv module.
    causal : bool
        True for causal attention, False otherwise.
    attn_dropout : float
        Dropout rate passed to each layer's attention.
    norm_type : str
        Either 'scalenorm' or 'layernorm'.
    shift_tokens : bool
        Whether token shifting is enabled inside each layer.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.,
        causal=False,
        attn_dropout=0.1,
        norm_type='scalenorm',
        shift_tokens=True,
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        self.group_size = group_size

        # Max rotary embedding dimension of 32 (partial rotary embeddings,
        # from Wang et al. - GPT-J).  A single instance is shared by all layers.
        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))

        self.layers = nn.ModuleList([
            FLASH_ShareA_FFConvM(
                dim=dim,
                group_size=group_size,
                query_key_dim=query_key_dim,
                expansion_factor=expansion_factor,
                causal=causal,
                dropout=attn_dropout,
                rotary_pos_emb=rotary_pos_emb,
                norm_klass=norm_klass,
                shift_tokens=shift_tokens,
            )
            for _ in range(depth)
        ])

    def forward(self, x, *, mask=None):
        """Apply every FLASH layer in order; tensor shape is preserved.

        Arguments
        ---------
        x : torch.Tensor
            Input sequence tensor.
        mask : torch.Tensor, optional
            Attention mask forwarded to every layer.
        """
        # (Removed a dead loop counter that was incremented but never read.)
        for flash in self.layers:
            x = flash(x, mask=mask)
        return x
class MossFormer_MaskNet(nn.Module):
    """The MossFormer module for computing output masks.

    Arguments
    ---------
    in_channels : int
        Number of channels at the output of the encoder.
    out_channels : int
        Number of channels that would be inputted to the intra and inter blocks.
    num_blocks : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    use_global_pos_enc : bool
        Global positional encodings.
    max_length : int
        Maximum sequence length.  NOTE(review): accepted but never used in
        this class — confirm whether it can be dropped.

    Example
    ---------
    >>> mossformer_block = MossFormerM(1, 64, 8)
    >>> mossformer_masknet = MossFormer_MaskNet(64, 64, intra_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_blocks=24,
        norm="ln",
        num_spks=2,
        skip_around_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super(MossFormer_MaskNet, self).__init__()
        self.num_spks = num_spks
        self.num_blocks = num_blocks
        # Input normalization over the encoder channels.
        self.norm = select_norm(norm, in_channels, 3)
        # 1x1 conv projecting encoder channels to the model width.
        self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc

        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        # Stack of MossFormer layers (with optional skip connection inside).
        self.mdl = Computation_Block(
            num_blocks,
            out_channels,
            norm,
            skip_around_intra=skip_around_intra,
        )

        # Expands the representation to one mask per speaker.
        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        # Projects each speaker's mask back to the encoder channel count.
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()

        # Gated output layer: a tanh branch multiplied by a sigmoid gate.
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [spks, B, N, S]
            where, spks = Number of speakers
               B = Batchsize,
               N = number of filters
               S = the number of time frames
        """
        # before each line we indicate the shape after executing the line
        # [B, N, L]
        x = self.norm(x)
        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            #x = self.pos_enc(x.transpose(1, -1)).transpose(1, -1) + x * (
            #    x.size(1) ** 0.5)
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            # NOTE(review): this swaps dim 0 with the last dim, while the
            # commented-out variant above transposed dims 1 and -1 instead.
            # Presumably the embedding broadcasts correctly against `base`
            # ([B, N, L]) — confirm the intended layout.
            emb = emb.transpose(0, -1)
            #print('base: {}, emb: {}'.format(base.shape, emb.shape))
            x = base + emb

        # [B, N, S]
        #for i in range(self.num_modules):
        #    x = self.dual_mdl[i](x)
        x = self.mdl(x)
        x = self.prelu(x)

        # [B, N*spks, S]
        x = self.conv1d_out(x)
        B, _, S = x.shape

        # [B*spks, N, S] — fold the speaker dimension into the batch so the
        # gated output layer processes each speaker independently.
        x = x.view(B * self.num_spks, -1, S)

        # [B*spks, N, S]
        x = self.output(x) * self.output_gate(x)

        # [B*spks, N, S]
        x = self.conv1_decoder(x)

        # [B, spks, N, S]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)

        # [spks, B, N, S]
        x = x.transpose(0, 1)
        return x
class MossFormerEncoder(nn.Module):
    """Convolutional encoder layer: strided Conv1d followed by ReLU.

    Arguments
    ---------
    kernel_size : int
        Length of the filters (stride is ``kernel_size // 2``).
    out_channels : int
        Number of output channels (filters).
    in_channels : int
        Number of input channels; a 1-channel input may be given as [B, L].

    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = MossFormerEncoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """

    def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
        super().__init__()
        self.conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            groups=1,
            bias=False,
        )
        self.in_channels = in_channels

    def forward(self, x):
        """Return the encoded output.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of shape [B, L] (mono) or [B, in_channels, L].

        Return
        ------
        x : torch.Tensor
            Encoded tensor of shape [B, N, T_out], where
            B = batch size, N = number of filters,
            T_out = number of output time frames.
        """
        # Mono input arrives as [B, L]; insert the channel axis.
        if self.in_channels == 1:
            x = torch.unsqueeze(x, dim=1)
        # [B, in_channels, L] -> [B, N, T_out], rectified.
        return F.relu(self.conv1d(x))
class MossFormerM(nn.Module):
    """Transformer-style encoder built from a MossformerBlock stack plus a
    final LayerNorm.

    Arguments
    ---------
    num_blocks : int
        Number of mossformer blocks to include.
    d_model : int
        The dimension of the input embedding.
    causal : bool
        True for causal / false for non causal attention.
    group_size : int
        The chunk size.
    query_key_dim : int
        The attention vector dimension.
    expansion_factor : float
        The expansion factor for the linear projection in the conv module.
    attn_dropout : float
        Dropout for the self-attention (optional).

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))
    >>> net = MossFormerM(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(
        self,
        num_blocks,
        d_model=None,
        causal=False,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.,
        attn_dropout=0.1,
    ):
        super().__init__()
        self.mossformerM = MossformerBlock(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout,
        )
        self.norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src):
        """Encode ``src`` and layer-normalize the result.

        Arguments
        ---------
        src : torch.Tensor
            Tensor of shape [B, L, N], where B = batch size,
            L = time points, N = number of filters.

        Returns
        -------
        torch.Tensor of the same shape as ``src``.
        """
        return self.norm(self.mossformerM(src))
class Computation_Block(nn.Module):
    """Computation block wrapping a MossFormerM with optional norm and skip.

    Arguments
    ---------
    num_blocks : int
        Number of MossFormer layers inside the intra model.
    out_channels : int
        Dimensionality of the intra model.
    norm : str or None
        Normalization type (see ``select_norm``); None disables it.
    skip_around_intra : bool
        Skip connection around the intra layer.

    Example
    ---------
    >>> comp_block = Computation_Block(64)
    >>> x = torch.randn(10, 64, 100)
    >>> x = comp_block(x)
    >>> x.shape
    torch.Size([10, 64, 100])
    """

    def __init__(
        self,
        num_blocks,
        out_channels,
        norm="ln",
        skip_around_intra=True,
    ):
        super(Computation_Block, self).__init__()
        # MossFormerM: the original MossFormer.  (A recurrent variant,
        # MossFormer2M, can be swapped in here.)
        self.intra_mdl = MossFormerM(num_blocks=num_blocks, d_model=out_channels)
        self.skip_around_intra = skip_around_intra

        # Keep the norm *type* so forward() can test whether norm is enabled.
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].

        Return
        ---------
        out : torch.Tensor
            Output tensor of dimension [B, N, S],
            where B = batch size, N = number of filters,
            S = sequence time index.
        """
        # (Removed an unused ``B, N, S = x.shape`` unpack: dead code.)
        # [B, S, N] — the MossFormer layers expect time before channels.
        intra = x.permute(0, 2, 1).contiguous()
        intra = self.intra_mdl(intra)
        # Back to [B, N, S].
        intra = intra.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            intra = self.intra_norm(intra)
        # [B, N, S] — residual connection around the intra model.
        if self.skip_around_intra:
            intra = intra + x
        return intra
|