audio.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import torch
  6. import numpy as np
  7. import torch.nn as nn
  8. from functools import partial
  9. import torch.nn.functional as F
  10. from typing import Callable, Dict
  11. from funasr.models.emotion2vec.fairseq_modules import (
  12. LayerNorm,
  13. SamePad,
  14. TransposeLast,
  15. ConvFeatureExtractionModel,
  16. )
  17. from funasr.models.emotion2vec.modules import Modality, BlockEncoder, Decoder1d
  18. from funasr.models.emotion2vec.base import ModalitySpecificEncoder, get_alibi_bias
  19. class AudioEncoder(ModalitySpecificEncoder):
  20. def __init__(
  21. self,
  22. modality_cfg,
  23. embed_dim: int,
  24. make_block: Callable[[float], nn.ModuleList],
  25. norm_layer: Callable[[int], nn.LayerNorm],
  26. layer_norm_first: bool,
  27. alibi_biases: Dict,
  28. ):
  29. self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
  30. feature_embed_dim = self.feature_enc_layers[-1][0]
  31. local_encoder = ConvFeatureExtractionModel(
  32. conv_layers=self.feature_enc_layers,
  33. dropout=0.0,
  34. mode=modality_cfg.extractor_mode,
  35. conv_bias=False,
  36. )
  37. project_features = nn.Sequential(
  38. TransposeLast(),
  39. nn.LayerNorm(feature_embed_dim),
  40. nn.Linear(feature_embed_dim, embed_dim),
  41. )
  42. num_pos_layers = modality_cfg.conv_pos_depth
  43. k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
  44. positional_encoder = nn.Sequential(
  45. TransposeLast(),
  46. *[
  47. nn.Sequential(
  48. nn.Conv1d(
  49. embed_dim,
  50. embed_dim,
  51. kernel_size=k,
  52. padding=k // 2,
  53. groups=modality_cfg.conv_pos_groups,
  54. ),
  55. SamePad(k),
  56. TransposeLast(),
  57. LayerNorm(embed_dim, elementwise_affine=False),
  58. TransposeLast(),
  59. nn.GELU(),
  60. )
  61. for _ in range(num_pos_layers)
  62. ],
  63. TransposeLast(),
  64. )
  65. if modality_cfg.conv_pos_pre_ln:
  66. positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
  67. dpr = np.linspace(
  68. modality_cfg.start_drop_path_rate,
  69. modality_cfg.end_drop_path_rate,
  70. modality_cfg.prenet_depth,
  71. )
  72. context_encoder = BlockEncoder(
  73. nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
  74. norm_layer(embed_dim) if not layer_norm_first else None,
  75. layer_norm_first,
  76. modality_cfg.prenet_layerdrop,
  77. modality_cfg.prenet_dropout,
  78. )
  79. decoder = (
  80. Decoder1d(modality_cfg.decoder, embed_dim)
  81. if modality_cfg.decoder is not None
  82. else None
  83. )
  84. alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
  85. super().__init__(
  86. modality_cfg=modality_cfg,
  87. embed_dim=embed_dim,
  88. local_encoder=local_encoder,
  89. project_features=project_features,
  90. fixed_positional_encoder=None,
  91. relative_positional_encoder=positional_encoder,
  92. context_encoder=context_encoder,
  93. decoder=decoder,
  94. get_alibi_bias=alibi_bias_fn,
  95. )
  96. def convert_padding_mask(self, x, padding_mask):
  97. def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
  98. """
  99. Computes the output length of the convolutional layers
  100. """
  101. def _conv_out_length(input_length, kernel_size, stride):
  102. return torch.floor((input_length - kernel_size) / stride + 1)
  103. for i in range(len(self.feature_enc_layers)):
  104. input_lengths = _conv_out_length(
  105. input_lengths,
  106. self.feature_enc_layers[i][1],
  107. self.feature_enc_layers[i][2],
  108. )
  109. return input_lengths.to(torch.long)
  110. if padding_mask is not None:
  111. input_lengths = (1 - padding_mask.long()).sum(-1)
  112. # apply conv formula to get real output_lengths
  113. output_lengths = get_feat_extract_output_lengths(input_lengths)
  114. if padding_mask.any():
  115. padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
  116. # these two operations makes sure that all values
  117. # before the output lengths indices are attended to
  118. padding_mask[
  119. (
  120. torch.arange(padding_mask.shape[0], device=padding_mask.device),
  121. output_lengths - 1,
  122. )
  123. ] = 1
  124. padding_mask = (
  125. 1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
  126. ).bool()
  127. else:
  128. padding_mask = torch.zeros(
  129. x.shape[:2], dtype=torch.bool, device=x.device
  130. )
  131. return padding_mask
  132. def reset_parameters(self):
  133. super().reset_parameters()
  134. for mod in self.project_features.children():
  135. if isinstance(mod, nn.Linear):
  136. mod.reset_parameters()
  137. if self.decoder is not None:
  138. self.decoder.reset_parameters()