# utils.py
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import math
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from funasr.modules.data2vec.multihead_attention import MultiheadAttention
  10. class Fp32LayerNorm(nn.LayerNorm):
  11. def __init__(self, *args, **kwargs):
  12. super().__init__(*args, **kwargs)
  13. def forward(self, input):
  14. output = F.layer_norm(
  15. input.float(),
  16. self.normalized_shape,
  17. self.weight.float() if self.weight is not None else None,
  18. self.bias.float() if self.bias is not None else None,
  19. self.eps,
  20. )
  21. return output.type_as(input)
  22. class Fp32GroupNorm(nn.GroupNorm):
  23. def __init__(self, *args, **kwargs):
  24. super().__init__(*args, **kwargs)
  25. def forward(self, input):
  26. output = F.group_norm(
  27. input.float(),
  28. self.num_groups,
  29. self.weight.float() if self.weight is not None else None,
  30. self.bias.float() if self.bias is not None else None,
  31. self.eps,
  32. )
  33. return output.type_as(input)
  34. class TransposeLast(nn.Module):
  35. def __init__(self, deconstruct_idx=None):
  36. super().__init__()
  37. self.deconstruct_idx = deconstruct_idx
  38. def forward(self, x):
  39. if self.deconstruct_idx is not None:
  40. x = x[self.deconstruct_idx]
  41. return x.transpose(-2, -1)
  42. class SamePad(nn.Module):
  43. def __init__(self, kernel_size, causal=False):
  44. super().__init__()
  45. if causal:
  46. self.remove = kernel_size - 1
  47. else:
  48. self.remove = 1 if kernel_size % 2 == 0 else 0
  49. def forward(self, x):
  50. if self.remove > 0:
  51. x = x[:, :, : -self.remove]
  52. return x
  53. def pad_to_multiple(x, multiple, dim=-1, value=0):
  54. # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
  55. if x is None:
  56. return None, 0
  57. tsz = x.size(dim)
  58. m = tsz / multiple
  59. remainder = math.ceil(m) * multiple - tsz
  60. if m.is_integer():
  61. return x, 0
  62. pad_offset = (0,) * (-1 - dim) * 2
  63. return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
  64. def gelu_accurate(x):
  65. if not hasattr(gelu_accurate, "_a"):
  66. gelu_accurate._a = math.sqrt(2 / math.pi)
  67. return (
  68. 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
  69. )
  70. def gelu(x: torch.Tensor) -> torch.Tensor:
  71. return torch.nn.functional.gelu(x.float()).type_as(x)
  72. def get_available_activation_fns():
  73. return [
  74. "relu",
  75. "gelu",
  76. "gelu_fast", # deprecated
  77. "gelu_accurate",
  78. "tanh",
  79. "linear",
  80. ]
  81. def get_activation_fn(activation: str):
  82. """Returns the activation function corresponding to `activation`"""
  83. if activation == "relu":
  84. return F.relu
  85. elif activation == "gelu":
  86. return gelu
  87. elif activation == "gelu_accurate":
  88. return gelu_accurate
  89. elif activation == "tanh":
  90. return torch.tanh
  91. elif activation == "linear":
  92. return lambda x: x
  93. elif activation == "swish":
  94. return torch.nn.SiLU
  95. else:
  96. raise RuntimeError("--activation-fn {} not supported".format(activation))
  97. def init_bert_params(module):
  98. """
  99. Initialize the weights specific to the BERT Model.
  100. This overrides the default initializations depending on the specified arguments.
  101. 1. If normal_init_linear_weights is set then weights of linear
  102. layer will be initialized using the normal distribution and
  103. bais will be set to the specified value.
  104. 2. If normal_init_embed_weights is set then weights of embedding
  105. layer will be initialized using the normal distribution.
  106. 3. If normal_init_proj_weights is set then weights of
  107. in_project_weight for MultiHeadAttention initialized using
  108. the normal distribution (to be validated).
  109. """
  110. def normal_(data):
  111. # with FSDP, module params will be on CUDA, so we cast them back to CPU
  112. # so that the RNG is consistent with and without FSDP
  113. data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
  114. if isinstance(module, nn.Linear):
  115. normal_(module.weight.data)
  116. if module.bias is not None:
  117. module.bias.data.zero_()
  118. if isinstance(module, nn.Embedding):
  119. normal_(module.weight.data)
  120. if module.padding_idx is not None:
  121. module.weight.data[module.padding_idx].zero_()
  122. if isinstance(module, MultiheadAttention):
  123. normal_(module.q_proj.weight.data)
  124. normal_(module.k_proj.weight.data)
  125. normal_(module.v_proj.weight.data)