# frontend.py
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy
import numpy as np
# Kaldi-compatible fbank feature extraction backend (kaldi-native-fbank).
import kaldi_native_fbank as knf
# Directory containing this file; handy for locating bundled resources.
root_dir = Path(__file__).resolve().parent
# Module-level registry of already-configured loggers (name -> flag).
logger_initialized = {}
  9. class WavFrontend():
  10. """Conventional frontend structure for ASR.
  11. """
  12. def __init__(
  13. self,
  14. cmvn_file: str = None,
  15. fs: int = 16000,
  16. window: str = 'hamming',
  17. n_mels: int = 80,
  18. frame_length: int = 25,
  19. frame_shift: int = 10,
  20. lfr_m: int = 1,
  21. lfr_n: int = 1,
  22. dither: float = 1.0,
  23. **kwargs,
  24. ) -> None:
  25. opts = knf.FbankOptions()
  26. opts.frame_opts.samp_freq = fs
  27. opts.frame_opts.dither = dither
  28. opts.frame_opts.window_type = window
  29. opts.frame_opts.frame_shift_ms = float(frame_shift)
  30. opts.frame_opts.frame_length_ms = float(frame_length)
  31. opts.mel_opts.num_bins = n_mels
  32. opts.energy_floor = 0
  33. opts.frame_opts.snip_edges = True
  34. opts.mel_opts.debug_mel = False
  35. self.opts = opts
  36. self.lfr_m = lfr_m
  37. self.lfr_n = lfr_n
  38. self.cmvn_file = cmvn_file
  39. if self.cmvn_file:
  40. self.cmvn = self.load_cmvn()
  41. self.fbank_fn = None
  42. self.fbank_beg_idx = 0
  43. self.reset_status()
  44. def fbank(self,
  45. waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  46. waveform = waveform * (1 << 15)
  47. self.fbank_fn = knf.OnlineFbank(self.opts)
  48. self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
  49. frames = self.fbank_fn.num_frames_ready
  50. mat = np.empty([frames, self.opts.mel_opts.num_bins])
  51. for i in range(frames):
  52. mat[i, :] = self.fbank_fn.get_frame(i)
  53. feat = mat.astype(np.float32)
  54. feat_len = np.array(mat.shape[0]).astype(np.int32)
  55. return feat, feat_len
  56. def fbank_online(self,
  57. waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  58. waveform = waveform * (1 << 15)
  59. # self.fbank_fn = knf.OnlineFbank(self.opts)
  60. self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
  61. frames = self.fbank_fn.num_frames_ready
  62. mat = np.empty([frames, self.opts.mel_opts.num_bins])
  63. for i in range(self.fbank_beg_idx, frames):
  64. mat[i, :] = self.fbank_fn.get_frame(i)
  65. # self.fbank_beg_idx += (frames-self.fbank_beg_idx)
  66. feat = mat.astype(np.float32)
  67. feat_len = np.array(mat.shape[0]).astype(np.int32)
  68. return feat, feat_len
  69. def reset_status(self):
  70. self.fbank_fn = knf.OnlineFbank(self.opts)
  71. self.fbank_beg_idx = 0
  72. def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  73. if self.lfr_m != 1 or self.lfr_n != 1:
  74. feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
  75. if self.cmvn_file:
  76. feat = self.apply_cmvn(feat)
  77. feat_len = np.array(feat.shape[0]).astype(np.int32)
  78. return feat, feat_len
  79. @staticmethod
  80. def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
  81. LFR_inputs = []
  82. T = inputs.shape[0]
  83. T_lfr = int(np.ceil(T / lfr_n))
  84. left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
  85. inputs = np.vstack((left_padding, inputs))
  86. T = T + (lfr_m - 1) // 2
  87. for i in range(T_lfr):
  88. if lfr_m <= T - i * lfr_n:
  89. LFR_inputs.append(
  90. (inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
  91. else:
  92. # process last LFR frame
  93. num_padding = lfr_m - (T - i * lfr_n)
  94. frame = inputs[i * lfr_n:].reshape(-1)
  95. for _ in range(num_padding):
  96. frame = np.hstack((frame, inputs[-1]))
  97. LFR_inputs.append(frame)
  98. LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
  99. return LFR_outputs
  100. def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
  101. """
  102. Apply CMVN with mvn data
  103. """
  104. frame, dim = inputs.shape
  105. means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
  106. vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
  107. inputs = (inputs + means) * vars
  108. return inputs
  109. def load_cmvn(self,) -> np.ndarray:
  110. with open(self.cmvn_file, 'r', encoding='utf-8') as f:
  111. lines = f.readlines()
  112. means_list = []
  113. vars_list = []
  114. for i in range(len(lines)):
  115. line_item = lines[i].split()
  116. if line_item[0] == '<AddShift>':
  117. line_item = lines[i + 1].split()
  118. if line_item[0] == '<LearnRateCoef>':
  119. add_shift_line = line_item[3:(len(line_item) - 1)]
  120. means_list = list(add_shift_line)
  121. continue
  122. elif line_item[0] == '<Rescale>':
  123. line_item = lines[i + 1].split()
  124. if line_item[0] == '<LearnRateCoef>':
  125. rescale_line = line_item[3:(len(line_item) - 1)]
  126. vars_list = list(rescale_line)
  127. continue
  128. means = np.array(means_list).astype(np.float64)
  129. vars = np.array(vars_list).astype(np.float64)
  130. cmvn = np.array([means, vars])
  131. return cmvn
  132. class WavFrontendOnline(WavFrontend):
  133. def __init__(self, **kwargs):
  134. super().__init__(**kwargs)
  135. # self.fbank_fn = knf.OnlineFbank(self.opts)
  136. # add variables
  137. self.frame_sample_length = int(self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000)
  138. self.frame_shift_sample_length = int(self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000)
  139. self.waveform = None
  140. self.reserve_waveforms = None
  141. self.input_cache = None
  142. self.lfr_splice_cache = []
  143. @staticmethod
  144. # inputs has catted the cache
  145. def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
  146. np.ndarray, np.ndarray, int]:
  147. """
  148. Apply lfr with data
  149. """
  150. LFR_inputs = []
  151. T = inputs.shape[0] # include the right context
  152. T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
  153. splice_idx = T_lfr
  154. for i in range(T_lfr):
  155. if lfr_m <= T - i * lfr_n:
  156. LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1))
  157. else: # process last LFR frame
  158. if is_final:
  159. num_padding = lfr_m - (T - i * lfr_n)
  160. frame = (inputs[i * lfr_n:]).reshape(-1)
  161. for _ in range(num_padding):
  162. frame = np.hstack((frame, inputs[-1]))
  163. LFR_inputs.append(frame)
  164. else:
  165. # update splice_idx and break the circle
  166. splice_idx = i
  167. break
  168. splice_idx = min(T - 1, splice_idx * lfr_n)
  169. lfr_splice_cache = inputs[splice_idx:, :]
  170. LFR_outputs = np.vstack(LFR_inputs)
  171. return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
  172. @staticmethod
  173. def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
  174. frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
  175. return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
  176. def fbank(
  177. self,
  178. input: np.ndarray,
  179. input_lengths: np.ndarray
  180. ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  181. self.fbank_fn = knf.OnlineFbank(self.opts)
  182. batch_size = input.shape[0]
  183. if self.input_cache is None:
  184. self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
  185. input = np.concatenate((self.input_cache, input), axis=1)
  186. frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
  187. # update self.in_cache
  188. self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
  189. waveforms = np.empty(0, dtype=np.float32)
  190. feats_pad = np.empty(0, dtype=np.float32)
  191. feats_lens = np.empty(0, dtype=np.int32)
  192. if frame_num:
  193. waveforms = []
  194. feats = []
  195. feats_lens = []
  196. for i in range(batch_size):
  197. waveform = input[i]
  198. waveforms.append(
  199. waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
  200. waveform = waveform * (1 << 15)
  201. self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
  202. frames = self.fbank_fn.num_frames_ready
  203. mat = np.empty([frames, self.opts.mel_opts.num_bins])
  204. for i in range(frames):
  205. mat[i, :] = self.fbank_fn.get_frame(i)
  206. feat = mat.astype(np.float32)
  207. feat_len = np.array(mat.shape[0]).astype(np.int32)
  208. feats.append(feat)
  209. feats_lens.append(feat_len)
  210. waveforms = np.stack(waveforms)
  211. feats_lens = np.array(feats_lens)
  212. feats_pad = np.array(feats)
  213. self.fbanks = feats_pad
  214. self.fbanks_lens = copy.deepcopy(feats_lens)
  215. return waveforms, feats_pad, feats_lens
  216. def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
  217. return self.fbanks, self.fbanks_lens
  218. def lfr_cmvn(
  219. self,
  220. input: np.ndarray,
  221. input_lengths: np.ndarray,
  222. is_final: bool = False
  223. ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
  224. batch_size = input.shape[0]
  225. feats = []
  226. feats_lens = []
  227. lfr_splice_frame_idxs = []
  228. for i in range(batch_size):
  229. mat = input[i, :input_lengths[i], :]
  230. lfr_splice_frame_idx = -1
  231. if self.lfr_m != 1 or self.lfr_n != 1:
  232. # update self.lfr_splice_cache in self.apply_lfr
  233. mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
  234. is_final)
  235. if self.cmvn_file is not None:
  236. mat = self.apply_cmvn(mat)
  237. feat_length = mat.shape[0]
  238. feats.append(mat)
  239. feats_lens.append(feat_length)
  240. lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
  241. feats_lens = np.array(feats_lens)
  242. feats_pad = np.array(feats)
  243. return feats_pad, feats_lens, lfr_splice_frame_idxs
  244. def extract_fbank(
  245. self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
  246. ) -> Tuple[np.ndarray, np.ndarray]:
  247. batch_size = input.shape[0]
  248. assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
  249. waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
  250. if feats.shape[0]:
  251. self.waveforms = waveforms if self.reserve_waveforms is None else np.concatenate(
  252. (self.reserve_waveforms, waveforms), axis=1)
  253. if not self.lfr_splice_cache:
  254. for i in range(batch_size):
  255. self.lfr_splice_cache.append(np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0))
  256. if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
  257. lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
  258. feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
  259. feats_lengths += lfr_splice_cache_np[0].shape[0]
  260. frame_from_waveforms = int(
  261. (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
  262. minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
  263. feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(feats, feats_lengths, is_final)
  264. if self.lfr_m == 1:
  265. self.reserve_waveforms = None
  266. else:
  267. reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
  268. # print('reserve_frame_idx: ' + str(reserve_frame_idx))
  269. # print('frame_frame: ' + str(frame_from_waveforms))
  270. self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
  271. sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
  272. self.waveforms = self.waveforms[:, :sample_length]
  273. else:
  274. # update self.reserve_waveforms and self.lfr_splice_cache
  275. self.reserve_waveforms = self.waveforms[:,
  276. :-(self.frame_sample_length - self.frame_shift_sample_length)]
  277. for i in range(batch_size):
  278. self.lfr_splice_cache[i] = np.concatenate((self.lfr_splice_cache[i], feats[i]), axis=0)
  279. return np.empty(0, dtype=np.float32), feats_lengths
  280. else:
  281. if is_final:
  282. self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
  283. feats = np.stack(self.lfr_splice_cache)
  284. feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
  285. feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
  286. if is_final:
  287. self.cache_reset()
  288. return feats, feats_lengths
  289. def get_waveforms(self):
  290. return self.waveforms
  291. def cache_reset(self):
  292. self.fbank_fn = knf.OnlineFbank(self.opts)
  293. self.reserve_waveforms = None
  294. self.input_cache = None
  295. self.lfr_splice_cache = []
  296. def load_bytes(input):
  297. middle_data = np.frombuffer(input, dtype=np.int16)
  298. middle_data = np.asarray(middle_data)
  299. if middle_data.dtype.kind not in 'iu':
  300. raise TypeError("'middle_data' must be an array of integers")
  301. dtype = np.dtype('float32')
  302. if dtype.kind != 'f':
  303. raise TypeError("'dtype' must be a floating point type")
  304. i = np.iinfo(middle_data.dtype)
  305. abs_max = 2 ** (i.bits - 1)
  306. offset = i.min + abs_max
  307. array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
  308. return array
  309. class SinusoidalPositionEncoderOnline():
  310. '''Streaming Positional encoding.
  311. '''
  312. def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
  313. batch_size = positions.shape[0]
  314. positions = positions.astype(dtype)
  315. log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
  316. inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
  317. inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
  318. scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
  319. encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
  320. return encoding.astype(dtype)
  321. def forward(self, x, start_idx=0):
  322. batch_size, timesteps, input_dim = x.shape
  323. positions = np.arange(1, timesteps+1+start_idx)[None, :]
  324. position_encoding = self.encode(positions, input_dim, x.dtype)
  325. return x + position_encoding[:, start_idx: start_idx + timesteps]
  326. def test():
  327. path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
  328. import librosa
  329. cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
  330. config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
  331. from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
  332. config = read_yaml(config_file)
  333. waveform, _ = librosa.load(path, sr=None)
  334. frontend = WavFrontend(
  335. cmvn_file=cmvn_file,
  336. **config['frontend_conf'],
  337. )
  338. speech, _ = frontend.fbank_online(waveform) #1d, (sample,), numpy
  339. feat, feat_len = frontend.lfr_cmvn(speech) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
  340. frontend.reset_status() # clear cache
  341. return feat, feat_len
# Allow running this module directly as a smoke test.
if __name__ == '__main__':
    test()