# global_mvn.py
from pathlib import Path
from typing import Tuple
from typing import Union

import numpy as np
import torch
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.inversible_interface import InversibleInterface
  10. class GlobalMVN(AbsNormalize, InversibleInterface):
  11. """Apply global mean and variance normalization
  12. TODO(kamo): Make this class portable somehow
  13. Args:
  14. stats_file: npy file
  15. norm_means: Apply mean normalization
  16. norm_vars: Apply var normalization
  17. eps:
  18. """
  19. def __init__(
  20. self,
  21. stats_file: Union[Path, str],
  22. norm_means: bool = True,
  23. norm_vars: bool = True,
  24. eps: float = 1.0e-20,
  25. ):
  26. assert check_argument_types()
  27. super().__init__()
  28. self.norm_means = norm_means
  29. self.norm_vars = norm_vars
  30. self.eps = eps
  31. stats_file = Path(stats_file)
  32. self.stats_file = stats_file
  33. stats = np.load(stats_file)
  34. if isinstance(stats, np.ndarray):
  35. # Kaldi like stats
  36. count = stats[0].flatten()[-1]
  37. mean = stats[0, :-1] / count
  38. var = stats[1, :-1] / count - mean * mean
  39. else:
  40. # New style: Npz file
  41. count = stats["count"]
  42. sum_v = stats["sum"]
  43. sum_square_v = stats["sum_square"]
  44. mean = sum_v / count
  45. var = sum_square_v / count - mean * mean
  46. std = np.sqrt(np.maximum(var, eps))
  47. self.register_buffer("mean", torch.from_numpy(mean))
  48. self.register_buffer("std", torch.from_numpy(std))
  49. def extra_repr(self):
  50. return (
  51. f"stats_file={self.stats_file}, "
  52. f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
  53. )
  54. def forward(
  55. self, x: torch.Tensor, ilens: torch.Tensor = None
  56. ) -> Tuple[torch.Tensor, torch.Tensor]:
  57. """Forward function
  58. Args:
  59. x: (B, L, ...)
  60. ilens: (B,)
  61. """
  62. if ilens is None:
  63. ilens = x.new_full([x.size(0)], x.size(1))
  64. norm_means = self.norm_means
  65. norm_vars = self.norm_vars
  66. self.mean = self.mean.to(x.device, x.dtype)
  67. self.std = self.std.to(x.device, x.dtype)
  68. mask = make_pad_mask(ilens, x, 1)
  69. # feat: (B, T, D)
  70. if norm_means:
  71. if x.requires_grad:
  72. x = x - self.mean
  73. else:
  74. x -= self.mean
  75. if x.requires_grad:
  76. x = x.masked_fill(mask, 0.0)
  77. else:
  78. x.masked_fill_(mask, 0.0)
  79. if norm_vars:
  80. x /= self.std
  81. return x, ilens
  82. def inverse(
  83. self, x: torch.Tensor, ilens: torch.Tensor = None
  84. ) -> Tuple[torch.Tensor, torch.Tensor]:
  85. if ilens is None:
  86. ilens = x.new_full([x.size(0)], x.size(1))
  87. norm_means = self.norm_means
  88. norm_vars = self.norm_vars
  89. self.mean = self.mean.to(x.device, x.dtype)
  90. self.std = self.std.to(x.device, x.dtype)
  91. mask = make_pad_mask(ilens, x, 1)
  92. if x.requires_grad:
  93. x = x.masked_fill(mask, 0.0)
  94. else:
  95. x.masked_fill_(mask, 0.0)
  96. if norm_vars:
  97. x *= self.std
  98. # feat: (B, T, D)
  99. if norm_means:
  100. x += self.mean
  101. x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
  102. return x, ilens