# frontend.py
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import numpy as np
import kaldi_native_fbank as knf
# Directory containing this source file; base path for resolving bundled resources.
root_dir = Path(__file__).resolve().parent
# NOTE(review): appears unused in this chunk — presumably a registry of already
# configured logger names shared with other modules; confirm before removing.
logger_initialized = {}
  8. class WavFrontend():
  9. """Conventional frontend structure for ASR.
  10. """
  11. def __init__(
  12. self,
  13. cmvn_file: str = None,
  14. fs: int = 16000,
  15. window: str = 'hamming',
  16. n_mels: int = 80,
  17. frame_length: int = 25,
  18. frame_shift: int = 10,
  19. lfr_m: int = 1,
  20. lfr_n: int = 1,
  21. dither: float = 1.0,
  22. **kwargs,
  23. ) -> None:
  24. opts = knf.FbankOptions()
  25. opts.frame_opts.samp_freq = fs
  26. opts.frame_opts.dither = dither
  27. opts.frame_opts.window_type = window
  28. opts.frame_opts.frame_shift_ms = float(frame_shift)
  29. opts.frame_opts.frame_length_ms = float(frame_length)
  30. opts.mel_opts.num_bins = n_mels
  31. opts.energy_floor = 0
  32. opts.frame_opts.snip_edges = True
  33. opts.mel_opts.debug_mel = False
  34. self.opts = opts
  35. self.lfr_m = lfr_m
  36. self.lfr_n = lfr_n
  37. self.cmvn_file = cmvn_file
  38. if self.cmvn_file:
  39. self.cmvn = self.load_cmvn()
  40. self.fbank_fn = None
  41. self.fbank_beg_idx = 0
  42. self.reset_status()
  43. def fbank(self,
  44. waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  45. waveform = waveform * (1 << 15)
  46. self.fbank_fn = knf.OnlineFbank(self.opts)
  47. self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
  48. frames = self.fbank_fn.num_frames_ready
  49. mat = np.empty([frames, self.opts.mel_opts.num_bins])
  50. for i in range(frames):
  51. mat[i, :] = self.fbank_fn.get_frame(i)
  52. feat = mat.astype(np.float32)
  53. feat_len = np.array(mat.shape[0]).astype(np.int32)
  54. return feat, feat_len
  55. def fbank_online(self,
  56. waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  57. waveform = waveform * (1 << 15)
  58. # self.fbank_fn = knf.OnlineFbank(self.opts)
  59. self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
  60. frames = self.fbank_fn.num_frames_ready
  61. mat = np.empty([frames, self.opts.mel_opts.num_bins])
  62. for i in range(self.fbank_beg_idx, frames):
  63. mat[i, :] = self.fbank_fn.get_frame(i)
  64. # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
  65. feat = mat.astype(np.float32)
  66. feat_len = np.array(mat.shape[0]).astype(np.int32)
  67. return feat, feat_len
  68. def reset_status(self):
  69. self.fbank_fn = knf.OnlineFbank(self.opts)
  70. self.fbank_beg_idx = 0
  71. def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  72. if self.lfr_m != 1 or self.lfr_n != 1:
  73. feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
  74. if self.cmvn_file:
  75. feat = self.apply_cmvn(feat)
  76. feat_len = np.array(feat.shape[0]).astype(np.int32)
  77. return feat, feat_len
  78. @staticmethod
  79. def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
  80. LFR_inputs = []
  81. T = inputs.shape[0]
  82. T_lfr = int(np.ceil(T / lfr_n))
  83. left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
  84. inputs = np.vstack((left_padding, inputs))
  85. T = T + (lfr_m - 1) // 2
  86. for i in range(T_lfr):
  87. if lfr_m <= T - i * lfr_n:
  88. LFR_inputs.append(
  89. (inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
  90. else:
  91. # process last LFR frame
  92. num_padding = lfr_m - (T - i * lfr_n)
  93. frame = inputs[i * lfr_n:].reshape(-1)
  94. for _ in range(num_padding):
  95. frame = np.hstack((frame, inputs[-1]))
  96. LFR_inputs.append(frame)
  97. LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
  98. return LFR_outputs
  99. def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
  100. """
  101. Apply CMVN with mvn data
  102. """
  103. frame, dim = inputs.shape
  104. means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
  105. vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
  106. inputs = (inputs + means) * vars
  107. return inputs
  108. def load_cmvn(self,) -> np.ndarray:
  109. with open(self.cmvn_file, 'r', encoding='utf-8') as f:
  110. lines = f.readlines()
  111. means_list = []
  112. vars_list = []
  113. for i in range(len(lines)):
  114. line_item = lines[i].split()
  115. if line_item[0] == '<AddShift>':
  116. line_item = lines[i + 1].split()
  117. if line_item[0] == '<LearnRateCoef>':
  118. add_shift_line = line_item[3:(len(line_item) - 1)]
  119. means_list = list(add_shift_line)
  120. continue
  121. elif line_item[0] == '<Rescale>':
  122. line_item = lines[i + 1].split()
  123. if line_item[0] == '<LearnRateCoef>':
  124. rescale_line = line_item[3:(len(line_item) - 1)]
  125. vars_list = list(rescale_line)
  126. continue
  127. means = np.array(means_list).astype(np.float64)
  128. vars = np.array(vars_list).astype(np.float64)
  129. cmvn = np.array([means, vars])
  130. return cmvn
  131. def load_bytes(input):
  132. middle_data = np.frombuffer(input, dtype=np.int16)
  133. middle_data = np.asarray(middle_data)
  134. if middle_data.dtype.kind not in 'iu':
  135. raise TypeError("'middle_data' must be an array of integers")
  136. dtype = np.dtype('float32')
  137. if dtype.kind != 'f':
  138. raise TypeError("'dtype' must be a floating point type")
  139. i = np.iinfo(middle_data.dtype)
  140. abs_max = 2 ** (i.bits - 1)
  141. offset = i.min + abs_max
  142. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  143. return array
  144. def test():
  145. path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
  146. import librosa
  147. cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
  148. config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
  149. from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
  150. config = read_yaml(config_file)
  151. waveform, _ = librosa.load(path, sr=None)
  152. frontend = WavFrontend(
  153. cmvn_file=cmvn_file,
  154. **config['frontend_conf'],
  155. )
  156. speech, _ = frontend.fbank_online(waveform) #1d, (sample,), numpy
  157. feat, feat_len = frontend.lfr_cmvn(speech) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
  158. frontend.reset_status() # clear cache
  159. return feat, feat_len
  160. if __name__ == '__main__':
  161. test()