# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
from funasr.modules.data2vec.ema_module import EMAModule
from funasr.modules.data2vec.grad_multiply import GradMultiply
from funasr.modules.data2vec.wav2vec2 import (
    ConvFeatureExtractionModel,
    TransformerEncoder,
)
from funasr.modules.nets_utils import make_pad_mask


def get_annealed_rate(start, end, curr_step, total_steps):
    """Linearly anneal a rate from ``start`` to ``end`` over ``total_steps``."""
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining
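
# Example: with the default EMA schedule (0.999 -> 0.9999 over 100k steps),
# get_annealed_rate(0.999, 0.9999, 50_000, 100_000) returns 0.99945 — the
# rate moves linearly from ``start`` at step 0 to ``end`` at ``total_steps``.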


class Data2VecEncoder(AbsEncoder):
    """Data2vec encoder (Baevski et al., 2022): a convolutional feature
    extractor followed by a Transformer. During pretraining, the student
    regresses the averaged top-k layer representations of an EMA teacher
    at the masked time steps."""

    def __init__(
        self,
        # for ConvFeatureExtractionModel
        input_size: int = None,
        extractor_mode: str = None,
        conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
        # for Transformer Encoder
        ## model architecture
        layer_type: str = "transformer",
        layer_norm_first: bool = False,
        encoder_layers: int = 12,
        encoder_embed_dim: int = 768,
        encoder_ffn_embed_dim: int = 3072,
        encoder_attention_heads: int = 12,
        activation_fn: str = "gelu",
        ## dropouts
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        encoder_layerdrop: float = 0.0,
        dropout_input: float = 0.0,
        dropout_features: float = 0.0,
        ## grad settings
        feature_grad_mult: float = 1.0,
        ## masking
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: int = 0,
        no_mask_overlap: bool = False,
        mask_min_space: int = 1,
        require_same_masks: bool = True,  # if True, collate_fn should clip samples to a common length
        mask_dropout: float = 0.0,
        ## channel masking
        mask_channel_length: int = 10,
        mask_channel_prob: float = 0.0,
        mask_channel_before: bool = False,
        mask_channel_selection: str = "static",
        mask_channel_other: int = 0,
        no_mask_channel_overlap: bool = False,
        mask_channel_min_space: int = 1,
        ## positional embeddings
        conv_pos: int = 128,
        conv_pos_groups: int = 16,
        pos_conv_depth: int = 1,
        max_positions: int = 100000,
        # EMA module
        average_top_k_layers: int = 8,
        layer_norm_target_layer: bool = False,
        instance_norm_target_layer: bool = False,
        instance_norm_targets: bool = False,
        layer_norm_targets: bool = False,
        batch_norm_target_layer: bool = False,
        group_norm_target_layer: bool = False,
        ema_decay: float = 0.999,
        ema_end_decay: float = 0.9999,
        ema_anneal_end_step: int = 100000,
        ema_transformer_only: bool = True,
        ema_layers_only: bool = True,
        min_target_var: float = 0.1,
        min_pred_var: float = 0.01,
        # Loss
        loss_beta: float = 0.0,
        loss_scale: float = None,
        # FP16 optimization
        required_seq_len_multiple: int = 2,
    ):
        assert check_argument_types()
        super().__init__()

        # ConvFeatureExtractionModel
        self.conv_feature_layers = conv_feature_layers
        # the layer spec is a python literal, e.g. "[(512,2,2)] + [(512,2,2)]",
        # i.e. a list of (dim, kernel_size, stride) tuples
        feature_enc_layers = eval(conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]
        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=extractor_mode,
            in_d=input_size,
        )

        # Transformer Encoder
        ## model architecture
        self.layer_type = layer_type
        self.layer_norm_first = layer_norm_first
        self.encoder_layers = encoder_layers
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.activation_fn = activation_fn
        ## dropouts
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.dropout_input = dropout_input
        self.dropout_features = dropout_features
        ## grad settings
        self.feature_grad_mult = feature_grad_mult
        ## masking
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = require_same_masks  # if True, collate_fn should clip samples to a common length
        self.mask_dropout = mask_dropout
        ## channel masking
        self.mask_channel_length = mask_channel_length
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_before = mask_channel_before
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        ## positional embeddings
        self.conv_pos = conv_pos
        self.conv_pos_groups = conv_pos_groups
        self.pos_conv_depth = pos_conv_depth
        self.max_positions = max_positions

        self.mask_emb = nn.Parameter(torch.FloatTensor(self.encoder_embed_dim).uniform_())

        self.encoder = TransformerEncoder(
            dropout=self.dropout,
            encoder_embed_dim=self.encoder_embed_dim,
            required_seq_len_multiple=required_seq_len_multiple,
            pos_conv_depth=self.pos_conv_depth,
            conv_pos=self.conv_pos,
            conv_pos_groups=self.conv_pos_groups,
            # transformer layers
            layer_type=self.layer_type,
            encoder_layers=self.encoder_layers,
            encoder_ffn_embed_dim=self.encoder_ffn_embed_dim,
            encoder_attention_heads=self.encoder_attention_heads,
            attention_dropout=self.attention_dropout,
            activation_dropout=self.activation_dropout,
            activation_fn=self.activation_fn,
            layer_norm_first=self.layer_norm_first,
            encoder_layerdrop=self.encoder_layerdrop,
            max_positions=self.max_positions,
        )

        ## projections and dropouts
        self.post_extract_proj = nn.Linear(self.extractor_embed, self.encoder_embed_dim)
        # note: the float dropout rates stored above are replaced here by module instances
        self.dropout_input = nn.Dropout(self.dropout_input)
        self.dropout_features = nn.Dropout(self.dropout_features)
        self.layer_norm = torch.nn.LayerNorm(self.extractor_embed)
        self.final_proj = nn.Linear(self.encoder_embed_dim, self.encoder_embed_dim)

        # EMA module
        self.average_top_k_layers = average_top_k_layers
        self.layer_norm_target_layer = layer_norm_target_layer
        self.instance_norm_target_layer = instance_norm_target_layer
        self.instance_norm_targets = instance_norm_targets
        self.layer_norm_targets = layer_norm_targets
        self.batch_norm_target_layer = batch_norm_target_layer
        self.group_norm_target_layer = group_norm_target_layer
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step
        self.ema_transformer_only = ema_transformer_only
        self.ema_layers_only = ema_layers_only
        self.min_target_var = min_target_var
        self.min_pred_var = min_pred_var
        self.ema = None

        # Loss
        self.loss_beta = loss_beta
        self.loss_scale = loss_scale

        # FP16 optimization
        self.required_seq_len_multiple = required_seq_len_multiple

        self.num_updates = 0
        logging.info("Data2VecEncoder settings: {}".format(self.__dict__))

    def make_ema_teacher(self):
        skip_keys = set()
        if self.ema_layers_only:
            self.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")
        self.ema = EMAModule(
            self.encoder if self.ema_transformer_only else self,
            ema_decay=self.ema_decay,
            ema_fp32=True,
            skip_keys=skip_keys,
        )
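
    # The EMA teacher follows the standard update
    #     theta_teacher <- decay * theta_teacher + (1 - decay) * theta_student,
    # applied on each call to set_num_updates(). Parameters listed in
    # ``skip_keys`` (here, the positional conv when ema_layers_only is set)
    # are excluded from this decay-averaging.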

    def set_num_updates(self, num_updates):
        if self.ema is None and self.final_proj is not None:
            logging.info("Making EMA Teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.ema_decay != self.ema_end_decay:
                if num_updates >= self.ema_anneal_end_step:
                    decay = self.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.ema_decay,
                        self.ema_end_decay,
                        num_updates,
                        self.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.ema_transformer_only else self)
        self.num_updates = num_updates
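
    # With the defaults, the decay anneals linearly from ema_decay=0.999 at
    # step 0 to ema_end_decay=0.9999 at ema_anneal_end_step=100000 and stays
    # there; since the decay never reaches 1, the teacher keeps being updated.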

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        # optionally mask whole feature channels before time masking
        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        # time masking: replace masked frames with the learned mask embedding
        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.require_same_masks,
                    mask_dropout=self.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x[mask_channel_indices] = 0

        return x, mask_indices
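
    # Example: for x of shape (B, T, C) = (2, 100, 768) with the defaults
    # (mask_prob=0.65, mask_length=10), compute_mask_indices returns a boolean
    # (2, 100) array marking spans of 10 consecutive frames; span starts are
    # sampled so that roughly mask_prob of the frames are covered (overlapping
    # spans lower the effective fraction), and each masked frame is
    # overwritten with the learned ``self.mask_emb`` vector.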

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """Compute the output lengths of the convolutional feature extractor."""

        def _conv_out_length(input_length, kernel_size, stride):
            # standard conv output length: floor((L - k) / s) + 1
            return torch.floor((input_length - kernel_size).to(torch.float32) / stride + 1)

        conv_cfg_list = eval(self.conv_feature_layers)
        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )
        return input_lengths.to(torch.long)
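
    # Example: with the default conv stack "[(512,2,2)] + [(512,2,2)]"
    # (two layers, kernel 2, stride 2), an input of length 100 becomes
    # floor((100 - 2) / 2) + 1 = 50 frames, then floor((50 - 2) / 2) + 1 = 25.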

    def forward(
        self,
        xs_pad,
        ilens=None,
        mask=False,
        features_only=True,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        # create padding_mask from ilens
        if ilens is not None:
            padding_mask = make_pad_mask(lengths=ilens).to(xs_pad.device)
        else:
            padding_mask = None

        features = xs_pad
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                # scale the gradient flowing back into the feature extractor
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        orig_padding_mask = padding_mask

        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply the conv formula to get the real output lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)
            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )
            # these two operations make sure that all values
            # before the output length indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        # keep a pre-dropout copy as the teacher input
        pre_encoder_features = None
        if self.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            # guard against ilens=None, where no padding mask exists
            encoder_out_lens = (
                (1 - padding_mask.long()).sum(1) if padding_mask is not None else None
            )
            return x, encoder_out_lens, None

        result = {
            "losses": {},
            "padding_mask": padding_mask,
            "x": x,
        }

        # build regression targets from the EMA teacher
        with torch.no_grad():
            self.ema.model.eval()
            if self.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=xs_pad,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.instance_norm_target_layer or self.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            # average the top-k teacher layers to form the target
            y = sum(target_layer_results) / len(target_layer_results)

            if self.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)  # TBC -> BTC

            y = y[mask_indices]

        # regress the teacher targets at the masked positions only
        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)
        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        # monitor representation variance to detect collapse
        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.min_target_var:
            msg = f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            logging.error(msg)
            raise Exception(msg)
        if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
            msg = f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            logging.error(msg)
            raise Exception(msg)

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000  # scaled by 1000 for logging

        return result
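
    # With the default encoder_embed_dim=768 and loss_scale unset, the
    # regression loss is scaled by 1 / sqrt(768) ~= 0.036, keeping its
    # magnitude roughly independent of the target dimensionality.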

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            # aggregate count, sum, and sum of squares across all workers
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)
            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)
            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
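
    # compute_var returns the mean per-dimension standard deviation. The
    # distributed branch uses the unbiased identity
    # var = (sum(x^2) - (sum(x))^2 / n) / (n - 1): e.g. for a single column
    # [0., 2.], n=2, sum=2, sum_sq=4, so var = 4/1 - 4/2 = 2 and std ~= 1.414,
    # matching y.var(dim=0) in the single-process branch.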

    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
        res = self.forward(
            xs_pad,
            ilens,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )

    def output_size(self) -> int:
        return self.encoder_embed_dim
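

# Minimal usage sketch (illustrative only; it assumes the feature extractor
# accepts padded fbank input of shape (batch, time, feature) and that the
# default conv stack, with total stride 4, is used):
#
#     encoder = Data2VecEncoder(input_size=80, extractor_mode="layer_norm")
#     xs_pad = torch.randn(2, 100, 80)
#     ilens = torch.tensor([100, 80])
#     feats, olens, _ = encoder.extract_features(xs_pad, ilens)
#     # feats: (2, 25, 768); olens gives the unpadded frame count per utterance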