#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Modified from https://github.com/ddlBoJack/emotion2vec/tree/main

import os
import time
import torch
import logging
import numpy as np
from functools import partial
from omegaconf import OmegaConf
import torch.nn.functional as F
from contextlib import contextmanager
from distutils.version import LooseVersion

from funasr.register import tables
from funasr.models.emotion2vec.modules import AltBlock
from funasr.models.emotion2vec.audio import AudioEncoder
from funasr.utils.load_utils import load_audio_text_image_video

logger = logging.getLogger(__name__)

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield


@tables.register("model_classes", "Emotion2vec")
class Emotion2vec(torch.nn.Module):
    """
    Author: Ziyang Ma, Zhisheng Zheng, Jiaxin Ye, Jinchao Li, Zhifu Gao, Shiliang Zhang, Xie Chen
    emotion2vec: Self-Supervised Pre-Training for Speech Emotion Representation
    https://arxiv.org/abs/2312.15185
    """

    def __init__(self, **kwargs):
        super().__init__()
        cfg = OmegaConf.create(kwargs["model_conf"])
        self.cfg = cfg
        make_layer_norm = partial(
            torch.nn.LayerNorm, eps=cfg.get("norm_eps"), elementwise_affine=cfg.get("norm_affine")
        )

        def make_block(drop_path, dim=None, heads=None):
            return AltBlock(
                cfg.get("embed_dim") if dim is None else dim,
                cfg.get("num_heads") if heads is None else heads,
                cfg.get("mlp_ratio"),
                qkv_bias=True,
                drop=cfg.get("encoder_dropout"),
                attn_drop=cfg.get("attention_dropout"),
                mlp_drop=cfg.get("activation_dropout"),
                post_mlp_drop=cfg.get("post_mlp_drop"),
                drop_path=drop_path,
                norm_layer=make_layer_norm,
                layer_norm_first=cfg.get("layer_norm_first"),
                ffn_targets=not cfg.get("end_of_block_targets"),
            )
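
        # `make_block` is shared: it is handed to the AudioEncoder below for its
        # own transformer layers, and reused for `self.blocks`, the main stack.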
        self.alibi_biases = {}
        self.modality_encoders = torch.nn.ModuleDict()
        enc = AudioEncoder(
            cfg.modalities.audio,
            cfg.get("embed_dim"),
            make_block,
            make_layer_norm,
            cfg.get("layer_norm_first"),
            self.alibi_biases,
        )
        self.modality_encoders["AUDIO"] = enc

        self.ema = None
        self.average_top_k_layers = cfg.get("average_top_k_layers")
        self.loss_beta = cfg.get("loss_beta")
        self.loss_scale = cfg.get("loss_scale")

        self.dropout_input = torch.nn.Dropout(cfg.get("dropout_input"))

        dpr = np.linspace(
            cfg.get("start_drop_path_rate"), cfg.get("end_drop_path_rate"), cfg.get("depth")
        )
        self.blocks = torch.nn.ModuleList([make_block(dpr[i]) for i in range(cfg.get("depth"))])

        self.norm = None
        if cfg.get("layer_norm_first"):
            self.norm = make_layer_norm(cfg.get("embed_dim"))

        vocab_size = kwargs.get("vocab_size", -1)
        self.proj = None
        if vocab_size > 0:
            self.proj = torch.nn.Linear(cfg.get("embed_dim"), vocab_size)
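
        # A sketch of the `model_conf` keys consumed above. The key names come
        # from the cfg.get(...) calls; the example values are assumptions in the
        # style of a base-size model, not the shipped config:
        #   embed_dim: 768, num_heads: 12, depth: 12, mlp_ratio: 4.0,
        #   norm_eps: 1e-5, norm_affine: true, layer_norm_first: false,
        #   encoder_dropout / attention_dropout / activation_dropout: 0.1,
        #   start_drop_path_rate / end_drop_path_rate: 0.0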

    def forward(
        self,
        source,
        target=None,
        id=None,
        mode=None,
        padding_mask=None,
        mask=True,
        features_only=False,
        force_remove_masked=False,
        remove_extra_tokens=True,
        precomputed_mask=None,
        **kwargs,
    ):
        feature_extractor = self.modality_encoders["AUDIO"]

        mask_seeds = None
        extractor_out = feature_extractor(
            source,
            padding_mask,
            mask,
            remove_masked=not features_only or force_remove_masked,
            clone_batch=self.cfg.get("clone_batch") if not features_only else 1,
            mask_seeds=mask_seeds,
            precomputed_mask=precomputed_mask,
        )
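
        # extractor_out holds the encoded token sequence plus masking metadata;
        # on the feature-extraction path (features_only=True) clone_batch is
        # forced to 1 and masked positions are kept rather than removed.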
        x = extractor_out["x"]
        encoder_mask = extractor_out["encoder_mask"]
        masked_padding_mask = extractor_out["padding_mask"]
        masked_alibi_bias = extractor_out.get("alibi_bias", None)
        alibi_scale = extractor_out.get("alibi_scale", None)

        if self.dropout_input is not None:
            x = self.dropout_input(x)
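
        # x: (batch, n_tokens, embed_dim); each AltBlock applies self-attention
        # (optionally scaled by a per-layer ALiBi-style bias) followed by an MLP.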
        layer_results = []
        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.cfg.get("layerdrop", 0) == 0
                or (np.random.random() > self.cfg.get("layerdrop", 0))
            ):
                ab = masked_alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)
                x, lr = blk(
                    x,
                    padding_mask=masked_padding_mask,
                    alibi_bias=ab,
                )
                if features_only:
                    layer_results.append(lr)

        if self.norm is not None:
            x = self.norm(x)

        if features_only:
            if remove_extra_tokens:
                x = x[:, feature_extractor.modality_cfg.num_extra_tokens:]
                if masked_padding_mask is not None:
                    masked_padding_mask = masked_padding_mask[
                        :, feature_extractor.modality_cfg.num_extra_tokens:
                    ]

            return {
                "x": x,
                "padding_mask": masked_padding_mask,
                "layer_results": layer_results,
                "mask": encoder_mask,
            }
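
        # The self-supervised training branch of upstream emotion2vec is not
        # ported here, so the call falls through to an implicit None when
        # features_only is False.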

    def extract_features(
        self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
    ):
        res = self.forward(
            source,
            mode=mode,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            remove_extra_tokens=remove_extra_tokens,
        )
        return res
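
    # Example use of extract_features (a sketch, assuming `model` is a loaded
    # Emotion2vec instance and `wav` is a 16 kHz mono float tensor of shape (1, T)):
    #   res = model.extract_features(wav)
    #   res["x"]             # frame-level features, shape (1, n_frames, embed_dim)
    #   res["padding_mask"]  # stays None when no padding_mask is passed in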

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # Input audio is expected to be 16 kHz, single-channel.
        granularity = kwargs.get("granularity", "utterance")
        extract_embedding = kwargs.get("extract_embedding", True)
        if self.proj is None:
            extract_embedding = True
        meta_data = {}

        # load (and, if needed, resample) the raw waveforms
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=16000,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        meta_data["batch_data_time"] = len(audio_sample_list[0]) / kwargs.get("fs", 16000)
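
        # meta_data: "load_data" is the load time in seconds; "batch_data_time"
        # is the audio duration in seconds (samples / sampling rate).
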
        results = []
        output_dir = kwargs.get("output_dir")
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        for i, wav in enumerate(audio_sample_list):
            source = wav.to(device=kwargs["device"])
            if self.cfg.normalize:
                # layer_norm over the full waveform, i.e. zero-mean/unit-variance
                source = F.layer_norm(source, source.shape)
            source = source.view(1, -1)

            feats = self.extract_features(source, padding_mask=None)
            x = feats["x"]
            feats = feats["x"].squeeze(0).cpu().numpy()
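            # feats: (n_frames, embed_dim); utterance granularity mean-pools
            # over the time axis into a single embedding vector.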
            if granularity == "utterance":
                feats = np.mean(feats, axis=0)

            if output_dir and extract_embedding:
                np.save(os.path.join(output_dir, "{}.npy".format(key[i])), feats)

            labels = tokenizer.token_list if tokenizer is not None else []
            scores = []
            if self.proj:
                x = x.mean(dim=1)
                x = self.proj(x)
                x = torch.softmax(x, dim=-1)
                scores = x[0].tolist()

            result_i = {"key": key[i], "labels": labels, "scores": scores}
            if extract_embedding:
                result_i["feats"] = feats
            results.append(result_i)

        return results, meta_data
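

if __name__ == "__main__":
    # Usage sketch (not part of the upstream file): run inference through
    # FunASR's AutoModel wrapper. The model id and the file name below are
    # assumptions for illustration, not verified against a specific release.
    from funasr import AutoModel

    model = AutoModel(model="iic/emotion2vec_base_finetuned")  # assumed model id
    res = model.generate(
        "example.wav",            # hypothetical 16 kHz mono recording
        granularity="utterance",  # mean-pool frame features into one embedding
        extract_embedding=True,   # include the "feats" vector in each result
    )
    print(res[0]["labels"], res[0]["scores"])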