fbank.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Part of the implementation is borrowed from espnet/espnet.
  3. from typing import Tuple
  4. import numpy as np
  5. import torch
  6. import torchaudio.compliance.kaldi as kaldi
  7. from funasr.models.frontend.abs_frontend import AbsFrontend
  8. from typeguard import check_argument_types
  9. from torch.nn.utils.rnn import pad_sequence
  10. import kaldi_native_fbank as knf
  11. class WavFrontend(AbsFrontend):
  12. """Conventional frontend structure for ASR.
  13. """
  14. def __init__(
  15. self,
  16. cmvn_file: str = None,
  17. fs: int = 16000,
  18. window: str = 'hamming',
  19. n_mels: int = 80,
  20. frame_length: int = 25,
  21. frame_shift: int = 10,
  22. filter_length_min: int = -1,
  23. filter_length_max: int = -1,
  24. lfr_m: int = 1,
  25. lfr_n: int = 1,
  26. dither: float = 1.0,
  27. snip_edges: bool = True,
  28. upsacle_samples: bool = True,
  29. ):
  30. assert check_argument_types()
  31. super().__init__()
  32. self.fs = fs
  33. self.window = window
  34. self.n_mels = n_mels
  35. self.frame_length = frame_length
  36. self.frame_shift = frame_shift
  37. self.filter_length_min = filter_length_min
  38. self.filter_length_max = filter_length_max
  39. self.lfr_m = lfr_m
  40. self.lfr_n = lfr_n
  41. self.cmvn_file = cmvn_file
  42. self.dither = dither
  43. self.snip_edges = snip_edges
  44. self.upsacle_samples = upsacle_samples
  45. def output_size(self) -> int:
  46. return self.n_mels * self.lfr_m
  47. def forward(
  48. self,
  49. input: torch.Tensor,
  50. input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  51. batch_size = input.size(0)
  52. feats = []
  53. feats_lens = []
  54. for i in range(batch_size):
  55. waveform_length = input_lengths[i]
  56. waveform = input[i][:waveform_length]
  57. waveform = waveform * (1 << 15)
  58. waveform = waveform.unsqueeze(0)
  59. mat = kaldi.fbank(waveform,
  60. num_mel_bins=self.n_mels,
  61. frame_length=self.frame_length,
  62. frame_shift=self.frame_shift,
  63. dither=self.dither,
  64. energy_floor=0.0,
  65. window_type=self.window,
  66. sample_frequency=self.fs)
  67. feat_length = mat.size(0)
  68. feats.append(mat)
  69. feats_lens.append(feat_length)
  70. feats_lens = torch.as_tensor(feats_lens)
  71. feats_pad = pad_sequence(feats,
  72. batch_first=True,
  73. padding_value=0.0)
  74. return feats_pad, feats_lens
  75. import kaldi_native_fbank as knf
  76. def fbank_knf(waveform):
  77. # sampling_rate = 16000
  78. # samples = torch.randn(16000 * 10)
  79. opts = knf.FbankOptions()
  80. opts.frame_opts.samp_freq = 16000
  81. opts.frame_opts.dither = 0.0
  82. opts.frame_opts.window_type = "hamming"
  83. opts.frame_opts.frame_shift_ms = 10.0
  84. opts.frame_opts.frame_length_ms = 25.0
  85. opts.mel_opts.num_bins = 80
  86. opts.energy_floor = 1
  87. opts.frame_opts.snip_edges = True
  88. opts.mel_opts.debug_mel = False
  89. fbank = knf.OnlineFbank(opts)
  90. waveform = waveform * (1 << 15)
  91. fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist())
  92. frames = fbank.num_frames_ready
  93. mat = np.empty([frames, opts.mel_opts.num_bins])
  94. for i in range(frames):
  95. mat[i, :] = fbank.get_frame(i)
  96. return mat
  97. if __name__ == '__main__':
  98. import librosa
  99. path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
  100. waveform, fs = librosa.load(path, sr=None)
  101. fbank = fbank_knf(waveform)
  102. frontend = WavFrontend(dither=0.0)
  103. waveform_tensor = torch.from_numpy(waveform)[None, :]
  104. fbank_torch, _ = frontend.forward(waveform_tensor, [waveform_tensor.size(1)])
  105. fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
  106. diff = fbank - fbank_torch
  107. diff_max = diff.max()
  108. diff_sum = diff.abs().sum()
  109. pass