wav_frontend.py

# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.

from typing import Tuple
import copy

import numpy as np
import torch
import torch.nn as nn
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence

import funasr.frontends.eend_ola_feature as eend_ola_feature
from funasr.register import tables


def load_cmvn(cmvn_file):
    """Load global CMVN statistics (mean shift and variance scale) from a
    Kaldi-nnet-style text file."""
    with open(cmvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    means_list = []
    vars_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                means_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    means = np.array(means_list).astype(np.float32)
    vars = np.array(vars_list).astype(np.float32)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
    return cmvn
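
# A minimal sketch of the text layout load_cmvn() expects (an am.mvn-style
# export; the dimensions and values here are illustrative, not from a real
# model). The slice line_item[3:len(line_item) - 1] keeps exactly the numbers
# between the '[' and ']' tokens of each <LearnRateCoef> line:
#
#   <AddShift> 3 3
#   <LearnRateCoef> 0 [ -8.3 -8.1 -7.9 ]
#   <Rescale> 3 3
#   <LearnRateCoef> 0 [ 0.61 0.62 0.65 ]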


def apply_cmvn(inputs, cmvn):  # noqa
    """Apply CMVN: add the mean shift, then multiply by the variance scale."""
    device = inputs.device
    frame, dim = inputs.shape
    means = cmvn[0:1, :dim]
    vars = cmvn[1:2, :dim]
    inputs += means.to(device)
    inputs *= vars.to(device)
    return inputs.type(torch.float32)


def apply_lfr(inputs, lfr_m, lfr_n):
    """Apply low frame rate (LFR): stack lfr_m consecutive frames and advance
    by lfr_n frames per output frame."""
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / lfr_n))
    # replicate the first frame as left context for the first LFR frame
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
        else:  # process the last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n:]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)
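
# A quick shape check of apply_lfr under the common Paraformer-style setting
# (lfr_m=7, lfr_n=6), as a sketch: 100 fbank frames of dim 80 become
# ceil(100 / 6) = 17 stacked frames of dim 7 * 80 = 560.
#
#   feats = torch.randn(100, 80)
#   lfr = apply_lfr(feats, lfr_m=7, lfr_n=6)
#   assert lfr.shape == (17, 560)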


@tables.register("frontend_classes", "WavFrontend")
class WavFrontend(nn.Module):
    """Conventional frontend structure for ASR."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,  # (sic) upstream spelling of "upscale_samples"
        **kwargs,
    ):
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length  # ms
        self.frame_shift = frame_shift  # ms
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

    def forward(
        self,
        input: torch.Tensor,
        input_lengths,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            if self.upsacle_samples:
                # kaldi.fbank expects samples in the 16-bit integer range
                waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs,
                              snip_edges=self.snip_edges)
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
        feats_lens = torch.as_tensor(feats_lens)
        if batch_size == 1:
            feats_pad = feats[0][None, :, :]
        else:
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
        return feats_pad, feats_lens

    def forward_fbank(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform * (1 << 15)
            waveform = waveform.unsqueeze(0)
            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.n_mels,
                              frame_length=self.frame_length,
                              frame_shift=self.frame_shift,
                              dither=self.dither,
                              energy_floor=0.0,
                              window_type=self.window,
                              sample_frequency=self.fs)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens

    def forward_lfr_cmvn(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
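
# The two-stage entry points let a caller split extraction into raw fbanks and
# LFR+CMVN. A sketch with illustrative shapes under the common 7/6 LFR setting
# (no cmvn_file, so normalization is skipped):
#
#   frontend = WavFrontend(fs=16000, n_mels=80, lfr_m=7, lfr_n=6)
#   wav = torch.randn(1, 16000)  # (batch, samples): 1 s at 16 kHz
#   fbanks, fbank_lens = frontend.forward_fbank(wav, torch.tensor([16000]))
#   # fbanks: (1, 98, 80) raw filterbank frames
#   feats, feats_lens = frontend.forward_lfr_cmvn(fbanks, fbank_lens)
#   # feats: (1, 17, 560), i.e. (B, T', n_mels * lfr_m) == frontend.output_size()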


@tables.register("frontend_classes", "WavFrontendOnline")
class WavFrontendOnline(nn.Module):
    """Conventional frontend structure for streaming ASR/VAD."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = 'hamming',
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        filter_length_min: int = -1,
        filter_length_max: int = -1,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        snip_edges: bool = True,
        upsacle_samples: bool = True,  # (sic) upstream spelling of "upscale_samples"
        **kwargs,
    ):
        super().__init__()
        self.fs = fs
        self.window = window
        self.n_mels = n_mels
        self.frame_length = frame_length  # ms
        self.frame_shift = frame_shift  # ms
        self.frame_sample_length = int(self.frame_length * self.fs / 1000)
        self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
        self.filter_length_min = filter_length_min
        self.filter_length_max = filter_length_max
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
        # all streaming state (waveform, fbank, and LFR splice caches) lives in
        # the cache dict produced by init_cache(), not on the module itself

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

    @staticmethod
    def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
        """Apply CMVN: add the mean shift, then multiply by the variance scale."""
        device = inputs.device
        dtype = inputs.dtype
        frame, dim = inputs.shape
        means = np.tile(cmvn[0:1, :dim], (frame, 1))
        vars = np.tile(cmvn[1:2, :dim], (frame, 1))
        inputs += torch.from_numpy(means).type(dtype).to(device)
        inputs *= torch.from_numpy(vars).type(dtype).to(device)
        return inputs.type(torch.float32)

    @staticmethod
    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int,
                  is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Apply LFR to a streaming chunk. Returns the stacked frames, the
        leftover frames to splice onto the next chunk, and the input-frame
        index at which that leftover cache begins."""
        LFR_inputs = []
        T = inputs.shape[0]  # includes the right context
        T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n))  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
            else:  # process the last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n:]).view(-1)
                    for _ in range(num_padding):
                        frame = torch.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # remember where the leftover starts and stop stacking
                    splice_idx = i
                    break
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        LFR_outputs = torch.vstack(LFR_inputs)
        return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
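
    # Streaming contract sketch (lfr_m=7, lfr_n=6, illustrative numbers): a
    # 23-frame chunk with is_final=False fits three full 7-frame windows, so
    # the call returns 3 stacked frames, keeps inputs[18:] (5 frames) as the
    # splice cache for the next chunk, and reports splice_idx=18; with
    # is_final=True the tail is instead padded by repeating the last frame.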

    @staticmethod
    def compute_frame_num(sample_length: int, frame_sample_length: int,
                          frame_shift_sample_length: int) -> int:
        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
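
    # For reference: 16000 samples with 25 ms frames and a 10 ms shift at
    # 16 kHz give (16000 - 400) // 160 + 1 = 98 frames; inputs shorter than
    # one frame yield 0.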

    def forward_fbank(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if cache is None:  # avoid a shared mutable default argument
            cache = {}
        batch_size = input.size(0)
        input = torch.cat((cache["input_cache"], input), dim=1)
        frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length,
                                           self.frame_shift_sample_length)
        # carry over the trailing samples that do not yet fill a whole frame
        cache["input_cache"] = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
        waveforms = torch.empty(0)
        feats_pad = torch.empty(0)
        feats_lens = torch.empty(0)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                # keep exactly the wave samples consumed by fbank extraction
                waveforms.append(
                    waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
                waveform = waveform * (1 << 15)
                waveform = waveform.unsqueeze(0)
                mat = kaldi.fbank(waveform,
                                  num_mel_bins=self.n_mels,
                                  frame_length=self.frame_length,
                                  frame_shift=self.frame_shift,
                                  dither=self.dither,
                                  energy_floor=0.0,
                                  window_type=self.window,
                                  sample_frequency=self.fs)
                feat_length = mat.size(0)
                feats.append(mat)
                feats_lens.append(feat_length)
            waveforms = torch.stack(waveforms)
            feats_lens = torch.as_tensor(feats_lens)
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
            cache["fbanks"] = feats_pad
            cache["fbanks_lens"] = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def forward_lfr_cmvn(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor,
        is_final: bool = False,
        cache: dict = None,
        **kwargs,
    ):
        if cache is None:  # avoid a shared mutable default argument
            cache = {}
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, :input_lengths[i], :]
            lfr_splice_frame_idx = -1  # stays -1 when LFR is disabled
            if self.lfr_m != 1 or self.lfr_n != 1:
                # apply_lfr also refreshes this utterance's splice cache
                mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final)
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor, **kwargs
    ):
        is_final = kwargs.get("is_final", False)
        cache = kwargs.get("cache", {})
        if len(cache) == 0:
            self.init_cache(cache)
        batch_size = input.shape[0]
        assert batch_size == 1, 'online feature extraction currently supports batch_size == 1 only'
        waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths, cache=cache)  # input: (B, T) samples
        if feats.shape[0]:
            cache["waveforms"] = torch.cat((cache["reserve_waveforms"], waveforms), dim=1)
            if not cache["lfr_splice_cache"]:  # initialize the LFR splice cache
                for i in range(batch_size):
                    cache["lfr_splice_cache"].append(
                        feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
            # LFR needs at least lfr_m frames: new frames + cached splice frames
            if feats_lengths[0] + cache["lfr_splice_cache"][0].shape[0] >= self.lfr_m:
                lfr_splice_cache_tensor = torch.stack(cache["lfr_splice_cache"])  # B T D
                feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                feats_lengths += lfr_splice_cache_tensor[0].shape[0]
                frame_from_waveforms = int(
                    (cache["waveforms"].shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if cache["reserve_waveforms"].numel() == 0 else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(
                    feats, feats_lengths, is_final, cache=cache)
                if self.lfr_m == 1:
                    cache["reserve_waveforms"] = torch.empty(0)
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    cache["reserve_waveforms"] = cache["waveforms"][
                        :, reserve_frame_idx * self.frame_shift_sample_length:
                           frame_from_waveforms * self.frame_shift_sample_length]
                    sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + \
                        self.frame_sample_length
                    cache["waveforms"] = cache["waveforms"][:, :sample_length]
            else:
                # too few frames for one LFR window: bank the waveforms and
                # splice frames, and emit nothing for this chunk
                cache["reserve_waveforms"] = cache["waveforms"][
                    :, :-(self.frame_sample_length - self.frame_shift_sample_length)]
                for i in range(batch_size):
                    cache["lfr_splice_cache"][i] = torch.cat((cache["lfr_splice_cache"][i], feats[i]), dim=0)
                return torch.empty(0), feats_lengths
        else:
            if is_final:
                cache["waveforms"] = waveforms if cache["reserve_waveforms"].numel() == 0 \
                    else cache["reserve_waveforms"]
                feats = torch.stack(cache["lfr_splice_cache"])
                feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache)
        # resetting the cache after a final chunk is left to the caller
        return feats, feats_lengths

    def init_cache(self, cache: dict = None):
        if cache is None:  # avoid a shared mutable default argument
            cache = {}
        cache["reserve_waveforms"] = torch.empty(0)
        cache["input_cache"] = torch.empty(0)
        cache["lfr_splice_cache"] = []
        cache["waveforms"] = None
        cache["fbanks"] = None
        cache["fbanks_lens"] = None
        return cache
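
# A minimal streaming sketch (chunk size is illustrative). The cache dict holds
# all state, so keep one cache per stream and pass is_final=True with the last
# chunk to flush the tail:
#
#   frontend = WavFrontendOnline(fs=16000, n_mels=80, lfr_m=7, lfr_n=6)
#   cache = frontend.init_cache({})
#   wav = torch.randn(1, 48000)  # the full stream, revealed 600 ms at a time
#   for start in range(0, wav.shape[1], 9600):
#       chunk = wav[:, start:start + 9600]
#       feats, feats_lens = frontend(
#           chunk, torch.tensor([chunk.shape[1]]),
#           cache=cache, is_final=start + 9600 >= wav.shape[1])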


class WavFrontendMel23(nn.Module):
    """Frontend producing 23-dim log-mel features via eend_ola_feature."""

    def __init__(
        self,
        fs: int = 16000,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
        **kwargs,
    ):
        super().__init__()
        self.fs = fs
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.n_mels = 23

    def output_size(self) -> int:
        # splice() stacks lfr_m context frames on each side of the current one
        return self.n_mels * (2 * self.lfr_m + 1)

    def forward(
        self,
        input: torch.Tensor,
        input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
        for i in range(batch_size):
            waveform_length = input_lengths[i]
            waveform = input[i][:waveform_length]
            waveform = waveform.numpy()
            mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
            mat = eend_ola_feature.transform(mat)
            mat = eend_ola_feature.splice(mat, context_size=self.lfr_m)
            mat = mat[::self.lfr_n]
            mat = torch.from_numpy(mat)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
        feats_lens = torch.as_tensor(feats_lens)
        feats_pad = pad_sequence(feats,
                                 batch_first=True,
                                 padding_value=0.0)
        return feats_pad, feats_lens
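

# A small runnable smoke test for the offline frontend: a sketch that assumes
# funasr and torchaudio are importable; the audio is random noise, so only the
# shapes are meaningful.
if __name__ == '__main__':
    frontend = WavFrontend(fs=16000, n_mels=80, lfr_m=7, lfr_n=6)
    wav = torch.randn(1, 16000)  # (batch, samples): 1 s of fake audio at 16 kHz
    wav_lens = torch.tensor([16000], dtype=torch.int32)
    feats, feats_lens = frontend(wav, wav_lens)
    # 98 fbank frames, stacked 7 at a time every 6 frames -> 17 frames of dim 560
    print(feats.shape, feats_lens)  # torch.Size([1, 17, 560]) tensor([17])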