|
|
@@ -7,12 +7,13 @@ import torch
|
|
|
from typeguard import check_argument_types
|
|
|
|
|
|
from funasr.modules.nets_utils import make_pad_mask
|
|
|
+from funasr.layers.abs_normalize import AbsNormalize
|
|
|
+from funasr.layers.inversible_interface import InversibleInterface
|
|
|
|
|
|
-class GlobalMVN(torch.nn.Module):
|
|
|
- """Apply global mean and variance normalization
|
|
|
|
|
|
+class GlobalMVN(AbsNormalize, InversibleInterface):
|
|
|
+ """Apply global mean and variance normalization
|
|
|
TODO(kamo): Make this class portable somehow
|
|
|
-
|
|
|
Args:
|
|
|
stats_file: npy file
|
|
|
norm_means: Apply mean normalization
|
|
|
@@ -63,7 +64,6 @@ class GlobalMVN(torch.nn.Module):
|
|
|
self, x: torch.Tensor, ilens: torch.Tensor = None
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
"""Forward function
|
|
|
-
|
|
|
Args:
|
|
|
x: (B, L, ...)
|
|
|
ilens: (B,)
|
|
|
@@ -115,4 +115,4 @@ class GlobalMVN(torch.nn.Module):
|
|
|
if norm_means:
|
|
|
x += self.mean
|
|
|
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
|
|
|
- return x, ilens
|
|
|
+ return x, ilens
|