model.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main

import logging
import os
import time
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from funasr.models.emotion2vec.modules import AltBlock
from funasr.models.emotion2vec.audio import AudioEncoder
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables

logger = logging.getLogger(__name__)

@tables.register("model_classes", "Emotion2vec")
class Emotion2vec(nn.Module):
    """
    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
    https://arxiv.org/abs/2312.15185
    """
    def __init__(self, **kwargs):
        super().__init__()

        cfg = OmegaConf.create(kwargs["model_conf"])
        self.cfg = cfg

        make_layer_norm = partial(
            nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.get("embed_dim") if dim is None else dim,
                cfg.get("num_heads") if heads is None else heads,
                cfg.get("mlp_ratio"),
                qkv_bias=True,
                drop=cfg.get("encoder_dropout"),
                attn_drop=cfg.get("attention_dropout"),
                mlp_drop=cfg.get("activation_dropout"),
                post_mlp_drop=cfg.get("post_mlp_drop"),
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.get("layer_norm_first"),
                ffn_targets=not cfg.get("end_of_block_targets"),
            )

        self.alibi_biases = {}
        self.modality_encoders = nn.ModuleDict()
        enc = AudioEncoder(
            cfg.modalities.audio,
            cfg.get("embed_dim"),
            make_block,
            make_layer_norm,
            cfg.get("layer_norm_first"),
            self.alibi_biases,
        )
        self.modality_encoders["AUDIO"] = enc

        self.ema = None
        self.average_top_k_layers = cfg.get("average_top_k_layers")
        self.loss_beta = cfg.get("loss_beta")
        self.loss_scale = cfg.get("loss_scale")

        self.dropout_input = nn.Dropout(cfg.get("dropout_input"))

        dpr = np.linspace(
            cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth")
        )
        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])

        self.norm = None
        if cfg.get("layer_norm_first"):
            self.norm = make_layer_norm(cfg.get("embed_dim"))
    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        **kwargs,
    ):
        feature_extractor = self.modality_encoders["AUDIO"]

        # run the audio modality encoder (feature extraction plus optional masking)
        mask_seeds = None
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.get("clone_batch") if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )

        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)

        # transformer blocks, with LayerDrop at training time and an optional
        # (per-layer scaled) ALiBi attention bias
        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.get("layerdrop", 0) == 0
                or (np.random.random() > self.cfg.get("layerdrop", 0))
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)

                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                # drop the modality encoder's extra (non-frame) tokens from
                # both the features and the padding mask
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens :
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }
    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res
    def generate(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # Inputs are expected to be 16 kHz mono audio; load_audio_text_image_video
        # below handles reading and resampling.
        granularity = kwargs.get("granularity", "utterance")

        meta_data = {}
        # load (and, if needed, resample) the audio
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=16000,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"

        results = []
        output_dir = kwargs.get("output_dir")
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        for i, wav in enumerate(audio_sample_list):
            source = wav.to(device=kwargs["device"])
            if self.cfg.normalize:
                source = F.layer_norm(source, source.shape)
            source = source.view(1, -1)

            feats = self.extract_features(source, padding_mask=None)
            feats = feats["x"].squeeze(0).cpu().numpy()
            if granularity == "frame":
                # keep the per-frame features as-is
                pass
            elif granularity == "utterance":
                # mean-pool over frames to get one utterance-level embedding
                feats = np.mean(feats, axis=0)

            result_i = {"key": key[i], "feats": feats}
            results.append(result_i)
            if output_dir:
                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)

        return results, meta_data
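

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): Emotion2vec is normally built
# by FunASR's AutoModel wrapper, which supplies `model_conf`, the device, and
# the frontend, rather than being instantiated directly. The checkpoint name
# "iic/emotion2vec_base" and the input path "example.wav" below are
# assumptions; substitute whichever emotion2vec checkpoint and audio file you
# actually use. `granularity` and `output_dir` are forwarded to
# Emotion2vec.generate via **kwargs (see above).
if __name__ == "__main__":
    from funasr import AutoModel

    auto_model = AutoModel(model="iic/emotion2vec_base")
    # granularity="utterance" mean-pools frame features into one embedding per
    # input; use granularity="frame" to keep the per-frame features instead.
    res = auto_model.generate("example.wav", granularity="utterance")
    # Assuming the wrapper passes through the dicts built in generate() above,
    # each result carries a "key" and a numpy "feats" array.
    print(res[0]["key"], res[0]["feats"].shape)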