# global_mvn.py (3.3 KB)
from pathlib import Path
from typing import Optional
from typing import Tuple
from typing import Union

import numpy as np
import torch

from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.inversible_interface import InversibleInterface
from funasr.modules.nets_utils import make_pad_mask
  9. class GlobalMVN(AbsNormalize, InversibleInterface):
  10. """Apply global mean and variance normalization
  11. TODO(kamo): Make this class portable somehow
  12. Args:
  13. stats_file: npy file
  14. norm_means: Apply mean normalization
  15. norm_vars: Apply var normalization
  16. eps:
  17. """
  18. def __init__(
  19. self,
  20. stats_file: Union[Path, str],
  21. norm_means: bool = True,
  22. norm_vars: bool = True,
  23. eps: float = 1.0e-20,
  24. ):
  25. super().__init__()
  26. self.norm_means = norm_means
  27. self.norm_vars = norm_vars
  28. self.eps = eps
  29. stats_file = Path(stats_file)
  30. self.stats_file = stats_file
  31. stats = np.load(stats_file)
  32. if isinstance(stats, np.ndarray):
  33. # Kaldi like stats
  34. count = stats[0].flatten()[-1]
  35. mean = stats[0, :-1] / count
  36. var = stats[1, :-1] / count - mean * mean
  37. else:
  38. # New style: Npz file
  39. count = stats["count"]
  40. sum_v = stats["sum"]
  41. sum_square_v = stats["sum_square"]
  42. mean = sum_v / count
  43. var = sum_square_v / count - mean * mean
  44. std = np.sqrt(np.maximum(var, eps))
  45. self.register_buffer("mean", torch.from_numpy(mean))
  46. self.register_buffer("std", torch.from_numpy(std))
  47. def extra_repr(self):
  48. return (
  49. f"stats_file={self.stats_file}, "
  50. f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
  51. )
  52. def forward(
  53. self, x: torch.Tensor, ilens: torch.Tensor = None
  54. ) -> Tuple[torch.Tensor, torch.Tensor]:
  55. """Forward function
  56. Args:
  57. x: (B, L, ...)
  58. ilens: (B,)
  59. """
  60. if ilens is None:
  61. ilens = x.new_full([x.size(0)], x.size(1))
  62. norm_means = self.norm_means
  63. norm_vars = self.norm_vars
  64. self.mean = self.mean.to(x.device, x.dtype)
  65. self.std = self.std.to(x.device, x.dtype)
  66. mask = make_pad_mask(ilens, x, 1)
  67. # feat: (B, T, D)
  68. if norm_means:
  69. if x.requires_grad:
  70. x = x - self.mean
  71. else:
  72. x -= self.mean
  73. if x.requires_grad:
  74. x = x.masked_fill(mask, 0.0)
  75. else:
  76. x.masked_fill_(mask, 0.0)
  77. if norm_vars:
  78. x /= self.std
  79. return x, ilens
  80. def inverse(
  81. self, x: torch.Tensor, ilens: torch.Tensor = None
  82. ) -> Tuple[torch.Tensor, torch.Tensor]:
  83. if ilens is None:
  84. ilens = x.new_full([x.size(0)], x.size(1))
  85. norm_means = self.norm_means
  86. norm_vars = self.norm_vars
  87. self.mean = self.mean.to(x.device, x.dtype)
  88. self.std = self.std.to(x.device, x.dtype)
  89. mask = make_pad_mask(ilens, x, 1)
  90. if x.requires_grad:
  91. x = x.masked_fill(mask, 0.0)
  92. else:
  93. x.masked_fill_(mask, 0.0)
  94. if norm_vars:
  95. x *= self.std
  96. # feat: (B, T, D)
  97. if norm_means:
  98. x += self.mean
  99. x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
  100. return x, ilens