# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple

import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
import kaldi_native_fbank as knf


class WavFrontend(AbsFrontend):
    """Conventional frontend structure for ASR."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

    def forward(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
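

# --- Hypothetical helper (not part of the original file) ---------------------
# WavFrontend stores lfr_m / lfr_n (and cmvn_file), but forward() above never
# applies them. The function below is only a minimal sketch of conventional
# low-frame-rate (LFR) stacking, under the assumption that lfr_m consecutive
# frames are concatenated at a stride of lfr_n frames, which is what
# output_size() (n_mels * lfr_m) implies.
def apply_lfr_sketch(feats: torch.Tensor, lfr_m: int, lfr_n: int) -> torch.Tensor:
    """Stack lfr_m consecutive frames every lfr_n frames.

    feats: (T, n_mels) features of a single utterance.
    Returns a (ceil(T / lfr_n), n_mels * lfr_m) tensor.
    """
    t = feats.size(0)
    t_lfr = int(np.ceil(t / lfr_n))
    # Pad with copies of the last frame so every window has lfr_m frames.
    # (A production implementation may additionally pad with left context;
    # this sketch keeps only right padding for brevity.)
    feats = torch.cat((feats, feats[-1:].repeat(lfr_m, 1)), dim=0)
    stacked = [feats[i * lfr_n:i * lfr_n + lfr_m].reshape(-1) for i in range(t_lfr)]
    return torch.stack(stacked)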


def fbank_knf(waveform):
    # sampling_rate = 16000
    # samples = torch.randn(16000 * 10)
    opts = knf.FbankOptions()
    opts.frame_opts.samp_freq = 16000
    opts.frame_opts.dither = 0.0
    opts.frame_opts.window_type = "hamming"
    opts.frame_opts.frame_shift_ms = 10.0
    opts.frame_opts.frame_length_ms = 25.0
    opts.mel_opts.num_bins = 80
    opts.energy_floor = 1
    opts.frame_opts.snip_edges = True
    opts.mel_opts.debug_mel = False

    fbank = knf.OnlineFbank(opts)
    waveform = waveform * (1 << 15)
    fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist())
    frames = fbank.num_frames_ready
    mat = np.empty([frames, opts.mel_opts.num_bins])
    for i in range(frames):
        mat[i, :] = fbank.get_frame(i)
    return mat


if __name__ == '__main__':
    import librosa

    path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    waveform, fs = librosa.load(path, sr=None)
    fbank = fbank_knf(waveform)

    frontend = WavFrontend(dither=0.0)
    waveform_tensor = torch.from_numpy(waveform)[None, :]
    waveform_lengths = torch.tensor([waveform_tensor.size(1)])
    fbank_torch, _ = frontend.forward(waveform_tensor, waveform_lengths)
    fbank_torch = fbank_torch.cpu().numpy()[0, :, :]

    # fbank is a numpy array, so use np.abs() rather than a Tensor-style .abs().
    diff = fbank - fbank_torch
    diff_max = diff.max()
    diff_sum = np.abs(diff).sum()
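
    # Illustrative check (not in the original): when the options of the two
    # implementations match, the filterbanks should be nearly identical.
    print(f"max diff: {diff_max:.6f}, sum |diff|: {diff_sum:.6f}")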