| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
- import torch
- import numpy as np
- import torch.nn as nn
- from functools import partial
- import torch.nn.functional as F
- from typing import Callable, Dict
- from funasr.models.emotion2vec.fairseq_modules import (
- LayerNorm,
- SamePad,
- TransposeLast,
- ConvFeatureExtractionModel,
- )
- from funasr.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
- from funasr.models.emotion2vec.base import ModalitySpecificEncoder, get_alibi_bias
class AudioEncoder(ModalitySpecificEncoder):
    """Modality-specific encoder for audio input.

    Builds the audio sub-modules and hands them to ``ModalitySpecificEncoder``:
    a ``ConvFeatureExtractionModel`` local encoder, a LayerNorm + Linear
    feature projection to ``embed_dim``, a stack of grouped 1-D convolutions
    used as the relative positional encoder, a ``BlockEncoder`` context
    encoder and, when configured, a ``Decoder1d``.
    """

    def __init__(
        self,
        modality_cfg,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
    ):
        """Construct all audio sub-modules and initialise the base encoder.

        Args:
            modality_cfg: Configuration object. Attributes read here:
                ``feature_encoder_spec``, ``extractor_mode``,
                ``conv_pos_depth``, ``conv_pos_width``, ``conv_pos_groups``,
                ``conv_pos_pre_ln``, ``start_drop_path_rate``,
                ``end_drop_path_rate``, ``prenet_depth``,
                ``prenet_layerdrop``, ``prenet_dropout`` and ``decoder``.
            embed_dim: Output size of the feature projection and channel
                count of the positional convolutions.
            make_block: Factory called once per prenet layer with that
                layer's drop-path rate.
            norm_layer: Factory called with ``embed_dim`` to build the norm
                given to the context encoder when ``layer_norm_first`` is
                false.
            layer_norm_first: Forwarded to ``BlockEncoder``; when true no
                norm layer is passed to it.
            alibi_biases: Dict bound as the ``alibi_biases`` keyword of
                ``get_alibi_bias``.
        """
        # SECURITY: eval() executes arbitrary Python taken from the config
        # string. It is kept because specs may be written as list expressions
        # (e.g. using `+` / `*`), which ast.literal_eval would reject -- TODO
        # confirm. Only load configs from trusted sources.
        # Each entry is indexed below as [0] = channels, [1] = kernel size,
        # [2] = stride.
        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        # TransposeLast presumably swaps the last two dims so that LayerNorm
        # and Linear act on the channel axis -- confirm against its definition.
        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        # The total positional width is split across the conv layers, with a
        # minimum kernel size of 3 per layer.
        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        pos_conv_layers = [
            nn.Sequential(
                nn.Conv1d(
                    embed_dim,
                    embed_dim,
                    kernel_size=k,
                    padding=k // 2,
                    groups=modality_cfg.conv_pos_groups,
                ),
                SamePad(k),
                TransposeLast(),
                LayerNorm(embed_dim, elementwise_affine=False),
                TransposeLast(),
                nn.GELU(),
            )
            for _ in range(num_pos_layers)
        ]
        positional_encoder = nn.Sequential(
            TransposeLast(),
            *pos_conv_layers,
            TransposeLast(),
        )
        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        # One drop-path rate per prenet block, linearly spaced between the
        # configured start and end rates.
        drop_path_rates = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(rate) for rate in drop_path_rates),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = None
        if modality_cfg.decoder is not None:
            decoder = Decoder1d(modality_cfg.decoder, embed_dim)

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        """Convert an input-level padding mask to the resolution of ``x``.

        Args:
            x: Tensor whose first two dimensions give the shape of the
                returned mask (presumably batch x frames -- TODO confirm).
            padding_mask: Mask over the raw input where non-zero / ``True``
                marks padded positions, or ``None``.

        Returns:
            ``None`` when ``padding_mask`` is ``None``; otherwise a boolean
            tensor of shape ``x.shape[:2]`` that is ``True`` at padded
            positions.
        """
        if padding_mask is None:
            return None

        if not padding_mask.any():
            # Nothing is padded, so no length arithmetic is needed.
            return torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device)

        # Number of non-padded input positions per row.
        lengths = (1 - padding_mask.long()).sum(-1)

        # Apply the conv output-length formula layer by layer to get the
        # number of valid positions after the feature extractor.
        for layer in self.feature_enc_layers:
            kernel_size, stride = layer[1], layer[2]
            lengths = torch.floor((lengths - kernel_size) / stride + 1)
        output_lengths = lengths.to(torch.long)

        # Mark the last valid position of every row, then use a reversed
        # cumulative sum so that every position up to and including it is 1.
        # NOTE(review): if an output length is 0, index -1 wraps around to the
        # last position and the whole row is treated as valid -- confirm that
        # inputs this short cannot reach here.
        valid = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
        rows = torch.arange(valid.shape[0], device=valid.device)
        valid[rows, output_lengths - 1] = 1
        valid = valid.flip([-1]).cumsum(-1).flip([-1])

        # Invert: True where the position is padding.
        return (1 - valid).bool()

    def reset_parameters(self):
        """Re-initialise parameters.

        Calls the base class reset, then resets every ``nn.Linear`` that is a
        direct child of ``project_features`` (other children are left as they
        are) and the decoder when one is present.
        """
        super().reset_parameters()
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()
|