| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- from pathlib import Path
- from typing import Tuple
- from typing import Union
- import numpy as np
- import torch
- from typeguard import check_argument_types
- from funasr.modules.nets_utils import make_pad_mask
- from funasr.layers.abs_normalize import AbsNormalize
- from funasr.layers.inversible_interface import InversibleInterface
class GlobalMVN(AbsNormalize, InversibleInterface):
    """Apply global mean and variance normalization.

    The per-dimension mean and standard deviation are computed once, at
    construction time, from accumulated statistics stored in ``stats_file``
    and are kept as module buffers (``mean`` and ``std``).

    TODO(kamo): Make this class portable somehow

    Args:
        stats_file: File loaded with ``numpy.load``. Either a plain array
            ("Kaldi like": row 0 holds the per-dimension sums followed by the
            frame count, row 1 holds the per-dimension sums of squares), or
            an npz archive with the keys ``count``, ``sum`` and
            ``sum_square``.
        norm_means: Apply mean normalization
        norm_vars: Apply var normalization
        eps: Lower bound applied to the variance before taking the square
            root, so that ``std`` never becomes zero.
    """

    def __init__(
        self,
        stats_file: Union[Path, str],
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        # NOTE(review): this check is skipped when Python runs with -O.
        assert check_argument_types()
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps
        stats_file = Path(stats_file)
        self.stats_file = stats_file

        stats = np.load(stats_file)
        if isinstance(stats, np.ndarray):
            # Kaldi like stats
            count = stats[0].flatten()[-1]
            mean = stats[0, :-1] / count
            var = stats[1, :-1] / count - mean * mean
        else:
            # New style: Npz file.
            # The archive keeps its file handle open until it is closed, so
            # read everything that is needed inside a ``with`` block.
            with stats:
                count = stats["count"]
                sum_v = stats["sum"]
                sum_square_v = stats["sum_square"]
            mean = sum_v / count
            var = sum_square_v / count - mean * mean
        std = np.sqrt(np.maximum(var, eps))

        self.register_buffer("mean", torch.from_numpy(mean))
        self.register_buffer("std", torch.from_numpy(std))

    def extra_repr(self):
        """Return the constructor settings shown in the module's repr."""
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    @staticmethod
    def _full_lengths(x: torch.Tensor) -> torch.Tensor:
        """Return lengths that treat every sequence in ``x`` as unpadded.

        The result is an integer tensor on the device of ``x``; lengths are
        counts, so they must not inherit the floating dtype of the features.
        """
        return torch.full(
            (x.size(0),), x.size(1), dtype=torch.long, device=x.device
        )

    def _stats_like(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` cast to the device and dtype of ``x``.

        The registered buffers are left untouched: overwriting them with a
        lower-precision copy would permanently degrade the statistics used
        by later calls with a higher-precision input.
        """
        return (
            self.mean.to(x.device, x.dtype),
            self.std.to(x.device, x.dtype),
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function

        Subtracts the global mean (if ``norm_means``), zeroes the positions
        selected by the pad mask, then divides by the global standard
        deviation (if ``norm_vars``).

        When ``x`` does not require grad it is modified in place and the same
        tensor object is returned.

        Args:
            x: (B, L, ...)
            ilens: (B,). If None, every sequence is taken to have length L.

        Returns:
            The normalized features and ``ilens``.
        """
        if ilens is None:
            ilens = self._full_lengths(x)
        mean, std = self._stats_like(x)
        mask = make_pad_mask(ilens, x, 1)

        # feat: (B, T, D)
        if self.norm_means:
            if x.requires_grad:
                # Out-of-place so the caller's tensor is not modified while
                # autograd is tracking it.
                x = x - mean
            else:
                x -= mean

        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if self.norm_vars:
            x /= std

        return x, ilens

    def inverse(
        self, x: torch.Tensor, ilens: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Undo the normalization applied by ``forward``.

        Zeroes the positions selected by the pad mask, multiplies by the
        global standard deviation (if ``norm_vars``), then adds the global
        mean (if ``norm_means``) and zeroes the masked positions again so
        that they do not end up holding the mean.

        When ``x`` does not require grad it is modified in place and the same
        tensor object is returned.

        Args:
            x: (B, L, ...)
            ilens: (B,). If None, every sequence is taken to have length L.

        Returns:
            The de-normalized features and ``ilens``.
        """
        if ilens is None:
            ilens = self._full_lengths(x)
        mean, std = self._stats_like(x)
        mask = make_pad_mask(ilens, x, 1)

        if x.requires_grad:
            x = x.masked_fill(mask, 0.0)
        else:
            x.masked_fill_(mask, 0.0)

        if self.norm_vars:
            x *= std

        # feat: (B, T, D)
        if self.norm_means:
            x += mean
            # The mask is built from ilens and the shape of x, neither of
            # which has changed, so the one computed above is reused.
            x.masked_fill_(mask, 0.0)
        return x, ilens
|