#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main

import os
import time
import torch
import logging
import numpy as np
from functools import partial
from omegaconf import OmegaConf
import torch.nn.functional as F
from contextlib import contextmanager
from distutils.version import LooseVersion

from funasr.register import tables
from funasr.models.emotion2vec.modules import AltBlock
from funasr.models.emotion2vec.audio import AudioEncoder
from funasr.utils.load_utils import load_audio_text_image_video

logger = logging.getLogger(__name__)

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # torch < 1.6.0 has no native AMP; provide a no-op autocast so callers
    # can use `with autocast():` unconditionally.
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "Emotion2vec")
class Emotion2vec(torch.nn.Module):
    """
    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
    https://arxiv.org/abs/2312.15185
    """

    def __init__(self, **kwargs):
        super().__init__()
        cfg = OmegaConf.create(kwargs["model_conf"])
        self.cfg = cfg

        make_layer_norm = partial(
            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
        )
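
        # Factory for the transformer encoder blocks; every block shares the
        # config-driven width, head count, and dropout settings unless
        # overridden per call.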
        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.get("embed_dim") if dim is None else dim,
                cfg.get("num_heads") if heads is None else heads,
                cfg.get("mlp_ratio"),
                qkv_bias=True,
                drop=cfg.get("encoder_dropout"),
                attn_drop=cfg.get("attention_dropout"),
                mlp_drop=cfg.get("activation_dropout"),
                post_mlp_drop=cfg.get("post_mlp_drop"),
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.get("layer_norm_first"),
                ffn_targets=not cfg.get("end_of_block_targets"),
            )
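
        # Modality-specific front-ends live in a ModuleDict keyed by modality
        # name, mirroring the multimodal data2vec layout; only audio is used
        # here. ALiBi bias tensors are cached and shared via self.alibi_biases.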
        self.alibi_biases = {}
        self.modality_encoders = torch.nn.ModuleDict()
        enc = AudioEncoder(
            cfg.modalities.audio,
            cfg.get("embed_dim"),
            make_block,
            make_layer_norm,
            cfg.get("layer_norm_first"),
            self.alibi_biases,
        )
        self.modality_encoders["AUDIO"] = enc

        self.ema = None
        self.average_top_k_layers = cfg.get("average_top_k_layers")
        self.loss_beta = cfg.get("loss_beta")
        self.loss_scale = cfg.get("loss_scale")

        self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))
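
        # Stochastic depth: the drop-path rate ramps linearly from the first
        # to the last of the `depth` encoder blocks.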
        dpr = np.linspace(
            cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth")
        )
        self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])

        self.norm = None
        if cfg.get("layer_norm_first"):
            self.norm = make_layer_norm(cfg.get("embed_dim"))
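
        # Optional classification head: present only when the checkpoint was
        # fine-tuned with a label vocabulary (e.g. emotion classes).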
        vocab_size = kwargs.get("vocab_size", -1)
        self.proj = None
        if vocab_size > 0:
            self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        **kwargs,
    ):
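        # The audio front-end performs local feature extraction and, when
        # mask=True, span masking; clone_batch > 1 repeats each sample with
        # different masks during pre-training only.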
        feature_extractor = self.modality_encoders["AUDIO"]

        mask_seeds = None
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.get("clone_batch") if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )
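
        # Unpack the front-end outputs: token sequence, masks, and optional
        # ALiBi attention-bias tensors.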
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)
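
        # Transformer encoder with LayerDrop: during training each block is
        # skipped with probability `layerdrop`; at inference all blocks run.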
        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.get("layerdrop", 0) == 0
                or (np.random.random() > self.cfg.get("layerdrop", 0))
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                # Drop any extra (non-frame) tokens prepended by the
                # front-end so outputs align with the audio frames.
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }
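
    # Inference-time convenience wrapper: run the encoder with masking
    # disabled and return hidden states only.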
    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # Inputs are expected to be 16 kHz mono speech; other sample rates
        # are resampled by load_audio_text_image_video below.
        granularity = kwargs.get("granularity", "utterance")
        extract_embedding = kwargs.get("extract_embedding", True)
        if self.proj is None:
            # Without a classification head, embeddings are the only output.
            extract_embedding = True

        meta_data = {}

        # Load the raw waveforms (the model consumes waveforms directly, not
        # fbank features).
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=16000,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        meta_data["batch_data_time"] = len(audio_sample_list[0]) / kwargs.get("fs", 16000)

        results = []
        output_dir = kwargs.get("output_dir")
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
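
        # Encode each utterance independently (batch size 1 per forward pass).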
        for i, wav in enumerate(audio_sample_list):
            source = wav.to(device=kwargs["device"])
            if self.cfg.normalize:
                # Normalize the whole waveform to zero mean / unit variance.
                source = F.layer_norm(source, source.shape)
            source = source.view(1, -1)

            feats = self.extract_features(source, padding_mask=None)
            x = feats["x"]
            feats = feats["x"].squeeze(0).cpu().numpy()
            if granularity == "utterance":
                # Mean-pool frame embeddings into one utterance-level vector;
                # "frame" granularity keeps the per-frame sequence unchanged.
                feats = np.mean(feats, axis=0)

            if output_dir and extract_embedding:
                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)

            labels = tokenizer.token_list if tokenizer is not None else []
            scores = []
            if self.proj:
                # Classification path: mean-pool over time, project to the
                # label set, and softmax the logits into probabilities.
                x = x.mean(dim=1)
                x = self.proj(x)
                x = torch.softmax(x, dim=-1)
                scores = x[0].tolist()

            result_i = {"key": key[i], "labels": labels, "scores": scores}
            if extract_embedding:
                result_i["feats"] = feats
            results.append(result_i)

        return results, meta_data
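

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the
    # FunASR AutoModel entry point and a fine-tuned checkpoint id such as
    # "iic/emotion2vec_base_finetuned"; the checkpoint id and wav path here
    # are illustrative. AutoModel instantiates this class from the model
    # config and routes generate() to Emotion2vec.inference().
    from funasr import AutoModel

    model = AutoModel(model="iic/emotion2vec_base_finetuned")
    res = model.generate(
        "example.wav",            # hypothetical 16 kHz mono recording
        granularity="utterance",  # "frame" returns per-frame embeddings
        extract_embedding=True,   # include the pooled feature vector
    )
    print(res[0]["labels"], res[0]["scores"])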