utterance_mvn.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import Tuple
  2. import torch
  3. from funasr.modules.nets_utils import make_pad_mask
  4. from funasr.layers.abs_normalize import AbsNormalize
  5. class UtteranceMVN(AbsNormalize):
  6. def __init__(
  7. self,
  8. norm_means: bool = True,
  9. norm_vars: bool = False,
  10. eps: float = 1.0e-20,
  11. ):
  12. super().__init__()
  13. self.norm_means = norm_means
  14. self.norm_vars = norm_vars
  15. self.eps = eps
  16. def extra_repr(self):
  17. return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
  18. def forward(
  19. self, x: torch.Tensor, ilens: torch.Tensor = None
  20. ) -> Tuple[torch.Tensor, torch.Tensor]:
  21. """Forward function
  22. Args:
  23. x: (B, L, ...)
  24. ilens: (B,)
  25. """
  26. return utterance_mvn(
  27. x,
  28. ilens,
  29. norm_means=self.norm_means,
  30. norm_vars=self.norm_vars,
  31. eps=self.eps,
  32. )
  33. def utterance_mvn(
  34. x: torch.Tensor,
  35. ilens: torch.Tensor = None,
  36. norm_means: bool = True,
  37. norm_vars: bool = False,
  38. eps: float = 1.0e-20,
  39. ) -> Tuple[torch.Tensor, torch.Tensor]:
  40. """Apply utterance mean and variance normalization
  41. Args:
  42. x: (B, T, D), assumed zero padded
  43. ilens: (B,)
  44. norm_means:
  45. norm_vars:
  46. eps:
  47. """
  48. if ilens is None:
  49. ilens = x.new_full([x.size(0)], x.size(1))
  50. ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
  51. # Zero padding
  52. if x.requires_grad:
  53. x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
  54. else:
  55. x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
  56. # mean: (B, 1, D)
  57. mean = x.sum(dim=1, keepdim=True) / ilens_
  58. if norm_means:
  59. x -= mean
  60. if norm_vars:
  61. var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
  62. std = torch.clamp(var.sqrt(), min=eps)
  63. x = x / std.sqrt()
  64. return x, ilens
  65. else:
  66. if norm_vars:
  67. y = x - mean
  68. y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
  69. var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
  70. std = torch.clamp(var.sqrt(), min=eps)
  71. x /= std
  72. return x, ilens