| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import librosa
- import torch
- from typing import Tuple
- from funasr.models.transformer.utils.nets_utils import make_pad_mask
class LogMel(torch.nn.Module):
    """Convert STFT features to log-Mel filterbank (fbank) features.

    The filterbank arguments are the same as those of ``librosa.filters.mel``.

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz).
            If `None`, use `fmin = 0`
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        log_base: base of the logarithm applied to the Mel energies.
            If `None`, the natural logarithm is used.
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = None,
        fmax: float = None,
        htk: bool = False,
        log_base: float = None,
    ):
        super().__init__()

        # Resolve the frequency range defaults before building the filterbank.
        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        _mel_options = dict(
            sr=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        # Kept for extra_repr().
        self.mel_options = _mel_options
        self.log_base = log_base

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        # Registered as a buffer so it follows the module across devices and
        # is stored in the state dict without being a trainable parameter.
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        """Return the Mel filterbank options as a ``key=value`` string."""
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self,
        feat: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Project ``feat`` onto the Mel filterbank and take the logarithm.

        Args:
            feat: input features of shape (B, T, D1).
            ilens: optional lengths, one per batch item. When given, the
                positions selected by ``make_pad_mask`` are set to 0.0.

        Returns:
            A tuple ``(logmel_feat, ilens)`` where ``logmel_feat`` has shape
            (B, T, D2). If ``ilens`` was not given, it is created as a long
            tensor filled with ``feat.size(1)`` for every batch item.
        """
        # Function-scope import keeps this block self-contained.
        import math

        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        # Floor the energies so that log() never sees zero.
        mel_feat = torch.clamp(mel_feat, min=1e-10)

        if self.log_base is None:
            logmel_feat = mel_feat.log()
        elif self.log_base == 2.0:
            logmel_feat = mel_feat.log2()
        elif self.log_base == 10.0:
            logmel_feat = mel_feat.log10()
        else:
            # Change of base: log_b(x) = ln(x) / ln(b). log_base is a Python
            # float, so math.log is required here; torch.log only accepts
            # tensors and would raise TypeError.
            logmel_feat = mel_feat.log() / math.log(self.log_base)

        # Zero padding
        if ilens is not None:
            logmel_feat = logmel_feat.masked_fill(
                make_pad_mask(ilens, logmel_feat, 1), 0.0
            )
        else:
            ilens = feat.new_full(
                [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
            )
        return logmel_feat, ilens
|