# utterance_mvn.py
from typing import Tuple
import torch
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
  6. class UtteranceMVN(AbsNormalize):
  7. def __init__(
  8. self,
  9. norm_means: bool = True,
  10. norm_vars: bool = False,
  11. eps: float = 1.0e-20,
  12. ):
  13. assert check_argument_types()
  14. super().__init__()
  15. self.norm_means = norm_means
  16. self.norm_vars = norm_vars
  17. self.eps = eps
  18. def extra_repr(self):
  19. return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
  20. def forward(
  21. self, x: torch.Tensor, ilens: torch.Tensor = None
  22. ) -> Tuple[torch.Tensor, torch.Tensor]:
  23. """Forward function
  24. Args:
  25. x: (B, L, ...)
  26. ilens: (B,)
  27. """
  28. return utterance_mvn(
  29. x,
  30. ilens,
  31. norm_means=self.norm_means,
  32. norm_vars=self.norm_vars,
  33. eps=self.eps,
  34. )
  35. def utterance_mvn(
  36. x: torch.Tensor,
  37. ilens: torch.Tensor = None,
  38. norm_means: bool = True,
  39. norm_vars: bool = False,
  40. eps: float = 1.0e-20,
  41. ) -> Tuple[torch.Tensor, torch.Tensor]:
  42. """Apply utterance mean and variance normalization
  43. Args:
  44. x: (B, T, D), assumed zero padded
  45. ilens: (B,)
  46. norm_means:
  47. norm_vars:
  48. eps:
  49. """
  50. if ilens is None:
  51. ilens = x.new_full([x.size(0)], x.size(1))
  52. ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
  53. # Zero padding
  54. if x.requires_grad:
  55. x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
  56. else:
  57. x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
  58. # mean: (B, 1, D)
  59. mean = x.sum(dim=1, keepdim=True) / ilens_
  60. if norm_means:
  61. x -= mean
  62. if norm_vars:
  63. var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
  64. std = torch.clamp(var.sqrt(), min=eps)
  65. x = x / std.sqrt()
  66. return x, ilens
  67. else:
  68. if norm_vars:
  69. y = x - mean
  70. y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
  71. var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
  72. std = torch.clamp(var.sqrt(), min=eps)
  73. x /= std
  74. return x, ilens