frontend.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # -*- encoding: utf-8 -*-
  2. from pathlib import Path
  3. from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
  4. import numpy as np
  5. from typeguard import check_argument_types
  6. import kaldi_native_fbank as knf
  7. root_dir = Path(__file__).resolve().parent
  8. logger_initialized = {}
  9. class WavFrontend():
  10. """Conventional frontend structure for ASR.
  11. """
  12. def __init__(
  13. self,
  14. cmvn_file: str = None,
  15. fs: int = 16000,
  16. window: str = 'hamming',
  17. n_mels: int = 80,
  18. frame_length: int = 25,
  19. frame_shift: int = 10,
  20. lfr_m: int = 1,
  21. lfr_n: int = 1,
  22. dither: float = 1.0,
  23. **kwargs,
  24. ) -> None:
  25. check_argument_types()
  26. opts = knf.FbankOptions()
  27. opts.frame_opts.samp_freq = fs
  28. opts.frame_opts.dither = dither
  29. opts.frame_opts.window_type = window
  30. opts.frame_opts.frame_shift_ms = float(frame_shift)
  31. opts.frame_opts.frame_length_ms = float(frame_length)
  32. opts.mel_opts.num_bins = n_mels
  33. opts.energy_floor = 0
  34. opts.frame_opts.snip_edges = True
  35. opts.mel_opts.debug_mel = False
  36. self.opts = opts
  37. self.lfr_m = lfr_m
  38. self.lfr_n = lfr_n
  39. self.cmvn_file = cmvn_file
  40. if self.cmvn_file:
  41. self.cmvn = self.load_cmvn()
  42. self.fbank_fn = None
  43. self.fbank_beg_idx = 0
  44. self.reset_status()
  45. def fbank(self,
  46. waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  47. waveform = waveform * (1 << 15)
  48. self.fbank_fn = knf.OnlineFbank(self.opts)
  49. self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
  50. frames = self.fbank_fn.num_frames_ready
  51. mat = np.empty([frames, self.opts.mel_opts.num_bins])
  52. for i in range(frames):
  53. mat[i, :] = self.fbank_fn.get_frame(i)
  54. feat = mat.astype(np.float32)
  55. feat_len = np.array(mat.shape[0]).astype(np.int32)
  56. return feat, feat_len
  57. def fbank_online(self,
  58. waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  59. waveform = waveform * (1 << 15)
  60. # self.fbank_fn = knf.OnlineFbank(self.opts)
  61. self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
  62. frames = self.fbank_fn.num_frames_ready
  63. mat = np.empty([frames, self.opts.mel_opts.num_bins])
  64. for i in range(self.fbank_beg_idx, frames):
  65. mat[i, :] = self.fbank_fn.get_frame(i)
  66. # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
  67. feat = mat.astype(np.float32)
  68. feat_len = np.array(mat.shape[0]).astype(np.int32)
  69. return feat, feat_len
  70. def reset_status(self):
  71. self.fbank_fn = knf.OnlineFbank(self.opts)
  72. self.fbank_beg_idx = 0
  73. def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  74. if self.lfr_m != 1 or self.lfr_n != 1:
  75. feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
  76. if self.cmvn_file:
  77. feat = self.apply_cmvn(feat)
  78. feat_len = np.array(feat.shape[0]).astype(np.int32)
  79. return feat, feat_len
  80. @staticmethod
  81. def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
  82. LFR_inputs = []
  83. T = inputs.shape[0]
  84. T_lfr = int(np.ceil(T / lfr_n))
  85. left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
  86. inputs = np.vstack((left_padding, inputs))
  87. T = T + (lfr_m - 1) // 2
  88. for i in range(T_lfr):
  89. if lfr_m <= T - i * lfr_n:
  90. LFR_inputs.append(
  91. (inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
  92. else:
  93. # process last LFR frame
  94. num_padding = lfr_m - (T - i * lfr_n)
  95. frame = inputs[i * lfr_n:].reshape(-1)
  96. for _ in range(num_padding):
  97. frame = np.hstack((frame, inputs[-1]))
  98. LFR_inputs.append(frame)
  99. LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
  100. return LFR_outputs
  101. def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
  102. """
  103. Apply CMVN with mvn data
  104. """
  105. frame, dim = inputs.shape
  106. means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
  107. vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
  108. inputs = (inputs + means) * vars
  109. return inputs
  110. def load_cmvn(self,) -> np.ndarray:
  111. with open(self.cmvn_file, 'r', encoding='utf-8') as f:
  112. lines = f.readlines()
  113. means_list = []
  114. vars_list = []
  115. for i in range(len(lines)):
  116. line_item = lines[i].split()
  117. if line_item[0] == '<AddShift>':
  118. line_item = lines[i + 1].split()
  119. if line_item[0] == '<LearnRateCoef>':
  120. add_shift_line = line_item[3:(len(line_item) - 1)]
  121. means_list = list(add_shift_line)
  122. continue
  123. elif line_item[0] == '<Rescale>':
  124. line_item = lines[i + 1].split()
  125. if line_item[0] == '<LearnRateCoef>':
  126. rescale_line = line_item[3:(len(line_item) - 1)]
  127. vars_list = list(rescale_line)
  128. continue
  129. means = np.array(means_list).astype(np.float64)
  130. vars = np.array(vars_list).astype(np.float64)
  131. cmvn = np.array([means, vars])
  132. return cmvn
  133. def load_bytes(input):
  134. middle_data = np.frombuffer(input, dtype=np.int16)
  135. middle_data = np.asarray(middle_data)
  136. if middle_data.dtype.kind not in 'iu':
  137. raise TypeError("'middle_data' must be an array of integers")
  138. dtype = np.dtype('float32')
  139. if dtype.kind != 'f':
  140. raise TypeError("'dtype' must be a floating point type")
  141. i = np.iinfo(middle_data.dtype)
  142. abs_max = 2 ** (i.bits - 1)
  143. offset = i.min + abs_max
  144. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  145. return array
  146. def test():
  147. path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
  148. import librosa
  149. cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
  150. config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
  151. from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
  152. config = read_yaml(config_file)
  153. waveform, _ = librosa.load(path, sr=None)
  154. frontend = WavFrontend(
  155. cmvn_file=cmvn_file,
  156. **config['frontend_conf'],
  157. )
  158. speech, _ = frontend.fbank_online(waveform) #1d, (sample,), numpy
  159. feat, feat_len = frontend.lfr_cmvn(speech) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
  160. frontend.reset_status() # clear cache
  161. return feat, feat_len
  162. if __name__ == '__main__':
  163. test()