utterance_mvn.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from typing import Tuple
  2. import torch
  3. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  4. from funasr.register import tables
  5. @tables.register("normalize_classes", "UtteranceMVN")
  6. class UtteranceMVN(torch.nn.Module):
  7. def __init__(
  8. self,
  9. norm_means: bool = True,
  10. norm_vars: bool = False,
  11. eps: float = 1.0e-20,
  12. ):
  13. super().__init__()
  14. self.norm_means = norm_means
  15. self.norm_vars = norm_vars
  16. self.eps = eps
  17. def extra_repr(self):
  18. return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
  19. def forward(
  20. self, x: torch.Tensor, ilens: torch.Tensor = None
  21. ) -> Tuple[torch.Tensor, torch.Tensor]:
  22. """Forward function
  23. Args:
  24. x: (B, L, ...)
  25. ilens: (B,)
  26. """
  27. return utterance_mvn(
  28. x,
  29. ilens,
  30. norm_means=self.norm_means,
  31. norm_vars=self.norm_vars,
  32. eps=self.eps,
  33. )
  34. def utterance_mvn(
  35. x: torch.Tensor,
  36. ilens: torch.Tensor = None,
  37. norm_means: bool = True,
  38. norm_vars: bool = False,
  39. eps: float = 1.0e-20,
  40. ) -> Tuple[torch.Tensor, torch.Tensor]:
  41. """Apply utterance mean and variance normalization
  42. Args:
  43. x: (B, T, D), assumed zero padded
  44. ilens: (B,)
  45. norm_means:
  46. norm_vars:
  47. eps:
  48. """
  49. if ilens is None:
  50. ilens = x.new_full([x.size(0)], x.size(1))
  51. ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
  52. # Zero padding
  53. if x.requires_grad:
  54. x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
  55. else:
  56. x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
  57. # mean: (B, 1, D)
  58. mean = x.sum(dim=1, keepdim=True) / ilens_
  59. if norm_means:
  60. x -= mean
  61. if norm_vars:
  62. var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
  63. std = torch.clamp(var.sqrt(), min=eps)
  64. x = x / std.sqrt()
  65. return x, ilens
  66. else:
  67. if norm_vars:
  68. y = x - mean
  69. y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
  70. var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
  71. std = torch.clamp(var.sqrt(), min=eps)
  72. x /= std
  73. return x, ilens