| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263 |
- from typing import List
- from typing import Tuple
- from typing import Union
- import librosa
- import numpy as np
- import torch
- from torch_complex.tensor import ComplexTensor
- from funasr.models.transformer.utils.nets_utils import make_pad_mask
class FeatureTransform(torch.nn.Module):
    """Convert an STFT into log-mel features with optional normalization.

    The processing chain is: power spectrum -> ``LogMel`` -> ``GlobalMVN``
    (only when ``stats_file`` is given) -> ``UtteranceMVN`` (only when
    ``apply_uttmvn`` is truthy).

    Args:
        fs: Sampling rate forwarded to ``LogMel``.
        n_fft: Number of FFT components forwarded to ``LogMel``.
        n_mels: Number of mel bands forwarded to ``LogMel``.
        fmin: Lowest mel frequency forwarded to ``LogMel``.
        fmax: Highest mel frequency forwarded to ``LogMel``.
        stats_file: Statistics file handed to ``GlobalMVN``; ``None``
            disables global normalization.
        apply_uttmvn: Whether per-utterance normalization is applied.
        uttmvn_norm_means: ``norm_means`` option of ``UtteranceMVN``.
        uttmvn_norm_vars: ``norm_vars`` option of ``UtteranceMVN``.
    """

    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # Truthiness test on purpose: `apply_uttmvn` is a bool flag, so an
        # `is not None` comparison would be true even for False and the
        # sub-module would be built although forward() never uses it.
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Compute features from a complex spectrogram.

        Args:
            x: Complex input of shape (B, T, F) or (B, T, C, F).
            ilens: Sequence lengths; a non-tensor value is converted to a
                tensor on ``x.device``.

        Returns:
            Tuple of the feature tensor and ``ilens`` (as a tensor).

        Raises:
            ValueError: If ``x`` is neither 3- nor 4-dimensional.
        """
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
class LogMel(torch.nn.Module):
    """Map spectral features onto a mel filterbank and take the logarithm.

    The constructor arguments are handed to ``librosa.filters.mel``.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel
            band (area normalization). Otherwise, leave all the triangles
            aiming for a peak value of 1.0
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        # Kept on the instance so extra_repr() can report the configuration.
        self.mel_options = {
            "sr": fs,
            "n_fft": n_fft,
            "n_mels": n_mels,
            "fmin": fmin,
            "fmax": fmax,
            "htk": htk,
            "norm": norm,
        }

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        filterbank = librosa.filters.mel(**self.mel_options)
        # filterbank: (D2, D1) -> stored transposed as (D1, D2)
        transposed = torch.from_numpy(filterbank.T).float()
        self.register_buffer("melmat", transposed)

    def extra_repr(self):
        """Return the mel options as a comma-separated ``key=value`` list."""
        fragments = [f"{k}={v}" for k, v in self.mel_options.items()]
        return ", ".join(fragments)

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Apply the filterbank and the log, then zero the padded frames.

        Args:
            feat: Input of shape (B, T, D1).
            ilens: Sequence lengths used to build the padding mask.

        Returns:
            Tuple of the log-mel tensor (B, T, D2) and ``ilens``.
        """
        # (B, T, D1) x (D1, D2) -> (B, T, D2)
        projected = torch.matmul(feat, self.melmat)
        # Small offset keeps the logarithm finite for empty bins.
        log_projected = (projected + 1e-20).log()

        # Zero padding
        pad_mask = make_pad_mask(ilens, log_projected, 1)
        return log_projected.masked_fill(pad_mask, 0.0), ilens
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    Args:
        stats_file: Path of a file read with ``np.load`` that holds a 1-dim
            array. The first ``(len(array) - 1) / 2`` elements are treated
            as the sum of features, the following elements (excluding the
            last one) as the sum of the squared features, and the last
            element as the number of samples.
        norm_means: If True, the global mean is subtracted in ``forward``.
        norm_vars: If True, features are divided by the global standard
            deviation in ``forward``.
        eps: Lower bound applied to the standard deviation.
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars

        self.stats_file = stats_file
        stats = np.load(stats_file)

        stats = stats.astype(float)
        assert (len(stats) - 1) % 2 == 0, stats.shape

        feat_dim = (len(stats) - 1) // 2
        count = stats.flatten()[-1]
        mean = stats[:feat_dim] / count
        var = stats[feat_dim:-1] / count - mean * mean
        # E[x^2] - E[x]^2 can come out slightly negative through rounding;
        # clamp at zero so the square root never produces NaN.
        std = np.maximum(np.sqrt(np.maximum(var, 0.0)), eps)

        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        """Return the configuration shown in the module's repr."""
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` in place with the stored statistics.

        Args:
            x: Features of shape (B, T, D); modified in place.
            ilens: Sequence lengths used to build the padding mask.

        Returns:
            Tuple of the normalized tensor (the same object as ``x``) and
            ``ilens``.
        """
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            # Adding the bias also shifts the padded frames, so reset them.
            # The in-place variant is required: plain masked_fill returns a
            # new tensor and would leave x untouched.
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
class UtteranceMVN(torch.nn.Module):
    """Module wrapper that forwards to ``utterance_mvn`` with stored options.

    Args:
        norm_means: Passed to ``utterance_mvn`` as ``norm_means``.
        norm_vars: Passed to ``utterance_mvn`` as ``norm_vars``.
        eps: Passed to ``utterance_mvn`` as ``eps``.
    """

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.eps = eps
        self.norm_vars = norm_vars
        self.norm_means = norm_means

    def extra_repr(self):
        """Return the normalization flags shown in the module's repr."""
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize ``x`` per utterance using the configured options."""
        options = {
            "norm_means": self.norm_means,
            "norm_vars": self.norm_vars,
            "eps": self.eps,
        }
        return utterance_mvn(x, ilens, **options)
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization.

    ``x`` is updated in place and returned. With both flags disabled the
    input is returned unchanged.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,) number of valid frames of each utterance
        norm_means: Subtract the per-utterance mean.
        norm_vars: Divide by the per-utterance standard deviation.
        eps: Lower bound applied to the variance before the square root.

    Returns:
        Tuple of the normalized tensor and ``ilens``.
    """
    if not norm_means and not norm_vars:
        # Nothing requested: hand the input back untouched instead of a
        # mean-subtracted copy.
        return x, ilens

    ilens_ = ilens.type_as(x)
    # mean: (B, D); padded frames are zero, so they do not bias the sum.
    mean = x.sum(dim=1) / ilens_[:, None]
    pad_mask = make_pad_mask(ilens, x, 1)

    if norm_means:
        x -= mean[:, None, :]
        # Subtracting the mean turns padded frames into -mean; reset them.
        # masked_fill_ is the in-place variant; plain masked_fill returns a
        # new tensor and would leave x unchanged.
        x.masked_fill_(pad_mask, 0.0)
        centered = x
    else:
        # Centered copy is only needed to measure the variance; x keeps its
        # mean.
        centered = (x - mean[:, None, :]).masked_fill(pad_mask, 0.0)

    if norm_vars:
        var = centered.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]

    return x, ilens
def feature_transform_for(args, n_fft):
    """Build a ``FeatureTransform`` from an options object.

    Args:
        args: Object exposing ``fbank_fs``, ``n_mels``, ``fbank_fmin``,
            ``fbank_fmax``, ``stats_file``, ``apply_uttmvn``,
            ``uttmvn_norm_means`` and ``uttmvn_norm_vars`` attributes.
        n_fft: Number of FFT components.

    Returns:
        The configured ``FeatureTransform`` instance.
    """
    # Mel options
    mel_kwargs = dict(
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
    )
    # Normalization
    norm_kwargs = dict(
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
    return FeatureTransform(**mel_kwargs, **norm_kwargs)
|