frontend.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # -*- encoding: utf-8 -*-
  2. from pathlib import Path
  3. from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
  4. import numpy as np
  5. from typeguard import check_argument_types
  6. import kaldi_native_fbank as knf
  # Absolute directory that contains this module file.
  root_dir = Path(__file__).resolve().parents[0]
  # Module-level dict; nothing in the visible code reads or writes it.
  logger_initialized = dict()
  class WavFrontend():
      """Conventional feature-extraction frontend for ASR.

      Computes mel filterbank ("fbank") features with ``kaldi_native_fbank``,
      optionally stacks them with low frame rate (LFR) stacking, and
      optionally applies CMVN statistics loaded from ``cmvn_file``.
      """

      def __init__(
          self,
          cmvn_file: Union[str, Path, None] = None,
          fs: int = 16000,
          window: str = 'hamming',
          n_mels: int = 80,
          frame_length: int = 25,
          frame_shift: int = 10,
          filter_length_min: int = -1,
          filter_length_max: float = -1,
          lfr_m: int = 1,
          lfr_n: int = 1,
          dither: float = 1.0
      ) -> None:
          """Configure the fbank extractor and load CMVN statistics.

          Args:
              cmvn_file: Path of the CMVN file. If falsy, CMVN is skipped.
              fs: Sampling rate in Hz given to the fbank extractor.
              window: Window type name (e.g. ``'hamming'``).
              n_mels: Number of mel bins (feature dimension before LFR).
              frame_length: Frame length in milliseconds.
              frame_shift: Frame shift in milliseconds.
              filter_length_min: Stored on the instance; unused in this class.
              filter_length_max: Stored on the instance; unused in this class.
              lfr_m: Number of consecutive frames stacked per LFR output frame.
              lfr_n: Stride, in input frames, between LFR output frames.
              dither: Dithering constant given to the fbank extractor.

          Raises:
              ValueError: if ``cmvn_file`` is given but is malformed
                  (see ``load_cmvn``).
          """
          check_argument_types()
          opts = knf.FbankOptions()
          opts.frame_opts.samp_freq = fs
          opts.frame_opts.dither = dither
          opts.frame_opts.window_type = window
          opts.frame_opts.frame_shift_ms = float(frame_shift)
          opts.frame_opts.frame_length_ms = float(frame_length)
          opts.mel_opts.num_bins = n_mels
          opts.energy_floor = 0
          opts.frame_opts.snip_edges = True
          opts.mel_opts.debug_mel = False
          self.opts = opts
          self.filter_length_min = filter_length_min
          self.filter_length_max = filter_length_max
          self.lfr_m = lfr_m
          self.lfr_n = lfr_n
          self.cmvn_file = cmvn_file
          # Always define ``cmvn`` so the attribute exists even without a file.
          self.cmvn = self.load_cmvn() if self.cmvn_file else None

      def fbank(self,
                waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
          """Compute filterbank features for a whole waveform.

          Args:
              waveform: 1-D array of samples. NOTE(review): presumably
                  normalized to [-1, 1]; it is multiplied by 2**15 before
                  extraction. Confirm against callers.

          Returns:
              ``(feat, feat_len)``: a float32 ``(num_frames, n_mels)`` matrix
              and its frame count as a 0-d int32 array.
          """
          # Rescale by 2**15 (int16 full scale) before feature extraction.
          waveform = waveform * (1 << 15)
          fbank_fn = knf.OnlineFbank(self.opts)
          fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
          frames = fbank_fn.num_frames_ready
          feat = np.empty([frames, self.opts.mel_opts.num_bins], dtype=np.float32)
          for i in range(frames):
              feat[i, :] = fbank_fn.get_frame(i)
          feat_len = np.array(feat.shape[0]).astype(np.int32)
          return feat, feat_len

      def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
          """Apply LFR stacking (if configured) and then CMVN (if loaded).

          Returns:
              ``(feat, feat_len)`` where ``feat_len`` is the number of output
              frames as a 0-d int32 array.
          """
          if self.lfr_m != 1 or self.lfr_n != 1:
              feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
          if self.cmvn_file:
              feat = self.apply_cmvn(feat)
          feat_len = np.array(feat.shape[0]).astype(np.int32)
          return feat, feat_len

      @staticmethod
      def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
          """Stack frames for low frame rate (LFR) processing.

          Output frame ``i`` is the concatenation of ``lfr_m`` consecutive
          input frames starting at input frame ``i * lfr_n - (lfr_m - 1) // 2``.
          The beginning is padded by repeating the first frame
          ``(lfr_m - 1) // 2`` times. Windows that run past the end are padded
          by repeating the last frame.

          Args:
              inputs: ``(T, D)`` feature matrix.
              lfr_m: Number of frames stacked per output frame.
              lfr_n: Stride between consecutive output frames.

          Returns:
              float32 array of shape ``(ceil(T / lfr_n), lfr_m * D)``. When
              ``inputs`` has no frames, this is an empty ``(0, lfr_m * D)``
              array.
          """
          num_frames, dim = inputs.shape
          num_out = -(-num_frames // lfr_n)  # ceil(T / lfr_n) in integer math
          if num_out == 0:
              return np.zeros((0, lfr_m * dim), dtype=np.float32)
          left = (lfr_m - 1) // 2
          # Rows the last output window reaches (in left-padded coordinates);
          # anything beyond the real data is filled with the final frame.
          needed = (num_out - 1) * lfr_n + lfr_m
          right = max(0, needed - (num_frames + left))
          padded = np.concatenate([
              np.repeat(inputs[:1], left, axis=0),
              inputs,
              np.repeat(inputs[-1:], right, axis=0),
          ])
          # Row indices of every window: shape (num_out, lfr_m).
          idx = np.arange(num_out)[:, None] * lfr_n + np.arange(lfr_m)[None, :]
          return padded[idx].reshape(num_out, -1).astype(np.float32)

      def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
          """Apply CMVN as ``(inputs + shift) * scale`` per feature dimension.

          ``self.cmvn[0]`` holds the shifts and ``self.cmvn[1]`` the scales.
          Only their first ``D`` columns are used, where ``D`` is the width of
          ``inputs``.

          Args:
              inputs: ``(T, D)`` feature matrix.

          Returns:
              Normalized ``(T, D)`` matrix. Its dtype follows NumPy promotion
              with ``self.cmvn`` (float64 when loaded by ``load_cmvn``).
          """
          _, dim = inputs.shape
          # Broadcasting a (D,) row over all frames replaces the explicit tile.
          shift = self.cmvn[0, :dim]
          scale = self.cmvn[1, :dim]
          return (inputs + shift) * scale

      def load_cmvn(self) -> np.ndarray:
          """Parse CMVN statistics from ``self.cmvn_file``.

          The file is scanned for lines starting with ``<AddShift>`` or
          ``<Rescale>``. When the next line starts with ``<LearnRateCoef>``,
          its tokens are read as ``<LearnRateCoef> <coef> [ v1 ... vN ]``: the
          first three tokens and the last one are dropped, and the remainder
          are the values. Blank lines are ignored. If a marker appears more
          than once, the last occurrence wins.

          Returns:
              float64 array of shape ``(2, N)``. Row 0 holds the shifts (added
              to features) and row 1 holds the scales (multiplied in).

          Raises:
              ValueError: if either section is missing or empty, or if the two
                  sections have different lengths.
          """
          with open(self.cmvn_file, 'r', encoding='utf-8') as f:
              tokens = [line.split() for line in f]
          means_list: List[str] = []
          vars_list: List[str] = []
          # Pair each line with its successor; blank lines have no tokens.
          for cur, nxt in zip(tokens, tokens[1:]):
              if not cur or not nxt or nxt[0] != '<LearnRateCoef>':
                  continue
              if cur[0] == '<AddShift>':
                  means_list = nxt[3:-1]
              elif cur[0] == '<Rescale>':
                  vars_list = nxt[3:-1]
          if not means_list or not vars_list:
              raise ValueError(
                  f"CMVN file {self.cmvn_file!r} must contain non-empty "
                  "<AddShift> and <Rescale> sections")
          if len(means_list) != len(vars_list):
              raise ValueError(
                  f"CMVN file {self.cmvn_file!r}: <AddShift> has "
                  f"{len(means_list)} values but <Rescale> has {len(vars_list)}")
          means = np.array(means_list).astype(np.float64)
          scales = np.array(vars_list).astype(np.float64)
          return np.array([means, scales])