# log_mel.py
import math
from typing import Optional, Tuple

import librosa
import torch

from funasr.models.transformer.utils.nets_utils import make_pad_mask
  5. class LogMel(torch.nn.Module):
  6. """Convert STFT to fbank feats
  7. The arguments is same as librosa.filters.mel
  8. Args:
  9. fs: number > 0 [scalar] sampling rate of the incoming signal
  10. n_fft: int > 0 [scalar] number of FFT components
  11. n_mels: int > 0 [scalar] number of Mel bands to generate
  12. fmin: float >= 0 [scalar] lowest frequency (in Hz)
  13. fmax: float >= 0 [scalar] highest frequency (in Hz).
  14. If `None`, use `fmax = fs / 2.0`
  15. htk: use HTK formula instead of Slaney
  16. """
  17. def __init__(
  18. self,
  19. fs: int = 16000,
  20. n_fft: int = 512,
  21. n_mels: int = 80,
  22. fmin: float = None,
  23. fmax: float = None,
  24. htk: bool = False,
  25. log_base: float = None,
  26. ):
  27. super().__init__()
  28. fmin = 0 if fmin is None else fmin
  29. fmax = fs / 2 if fmax is None else fmax
  30. _mel_options = dict(
  31. sr=fs,
  32. n_fft=n_fft,
  33. n_mels=n_mels,
  34. fmin=fmin,
  35. fmax=fmax,
  36. htk=htk,
  37. )
  38. self.mel_options = _mel_options
  39. self.log_base = log_base
  40. # Note(kamo): The mel matrix of librosa is different from kaldi.
  41. melmat = librosa.filters.mel(**_mel_options)
  42. # melmat: (D2, D1) -> (D1, D2)
  43. self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
  44. def extra_repr(self):
  45. return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
  46. def forward(
  47. self,
  48. feat: torch.Tensor,
  49. ilens: torch.Tensor = None,
  50. ) -> Tuple[torch.Tensor, torch.Tensor]:
  51. # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
  52. mel_feat = torch.matmul(feat, self.melmat)
  53. mel_feat = torch.clamp(mel_feat, min=1e-10)
  54. if self.log_base is None:
  55. logmel_feat = mel_feat.log()
  56. elif self.log_base == 2.0:
  57. logmel_feat = mel_feat.log2()
  58. elif self.log_base == 10.0:
  59. logmel_feat = mel_feat.log10()
  60. else:
  61. logmel_feat = mel_feat.log() / torch.log(self.log_base)
  62. # Zero padding
  63. if ilens is not None:
  64. logmel_feat = logmel_feat.masked_fill(
  65. make_pad_mask(ilens, logmel_feat, 1), 0.0
  66. )
  67. else:
  68. ilens = feat.new_full(
  69. [feat.size(0)], fill_value=feat.size(1), dtype=torch.long
  70. )
  71. return logmel_feat, ilens