vad_bin.py

# -*- encoding: utf-8 -*-
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import copy

import librosa
import numpy as np

from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
from .utils.frontend import WavFrontend, WavFrontendOnline
from .utils.e2e_vad import E2EVadModel

logging = get_logger()


class Fsmn_vad():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = (max_end_sil if max_end_sil is not None
                            else config["vad_post_conf"]["max_end_silence_time"])
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        # Avoid a mutable default argument: a shared `[]` default would leak
        # cached FSMN states across calls.
        if in_cache:
            return in_cache
        in_cache = []
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        # One zero-initialized cache per FSMN layer, shaped
        # (batch, proj_dim, lorder - 1, 1).
        for _ in range(fsmn_layers):
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache

    def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        # `[[]] * n` would alias a single list n times; build independent lists.
        segments = [[] for _ in range(self.batch_size)]
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            waveform = waveform_list[beg_idx:end_idx]
            feats, feats_len = self.extract_feat(waveform)
            waveform = np.array(waveform)
            param_dict = kwargs.get('param_dict', dict())
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                total_len = int(feats_len.max())
                step = min(total_len, 6000)
                # Stream the utterance through the model in chunks of at most
                # 6000 frames, carrying the FSMN caches between chunks.
                for t_offset in range(0, total_len, step):
                    if t_offset + step >= total_len - 1:
                        step = total_len - t_offset
                        is_final = True
                    else:
                        is_final = False
                    feats_package = feats[:, t_offset:t_offset + step, :]
                    # 160 samples per 10 ms frame shift and a 400-sample
                    # (25 ms) frame window at 16 kHz.
                    waveform_package = waveform[:, t_offset * 160:min(
                        waveform.shape[-1], (t_offset + step - 1) * 160 + 400)]
                    inputs = [feats_package]
                    inputs.extend(in_cache)
                    scores, out_caches = self.infer(inputs)
                    in_cache = out_caches
                    segments_part = self.vad_scorer(
                        scores, waveform_package, is_final=is_final,
                        max_end_sil=self.max_end_sil, online=False)
                    if segments_part:
                        for batch_num in range(0, self.batch_size):
                            segments[batch_num] += segments_part[batch_num]
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                segments = []
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
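

# A minimal usage sketch for the offline VAD above. The model directory and
# wav path are hypothetical placeholders; the directory is assumed to contain
# model.onnx (or model_quant.onnx), vad.yaml and vad.mvn, as loaded in
# __init__. Segments are assumed to come back as [[start_ms, end_ms], ...]
# per input waveform.
def _example_offline_vad():
    model_dir = '/path/to/fsmn_vad_model'  # hypothetical path
    vad = Fsmn_vad(model_dir, batch_size=1)
    segments = vad('/path/to/test.wav')    # hypothetical 16 kHz wav file
    for start_ms, end_ms in segments[0]:
        print(f'speech: {start_ms} ms .. {end_ms} ms')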


class Fsmn_vad_online():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = (max_end_sil if max_end_sil is not None
                            else config["vad_post_conf"]["max_end_silence_time"])
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        # Avoid a mutable default argument: a shared `[]` default would leak
        # cached FSMN states across calls.
        if in_cache:
            return in_cache
        in_cache = []
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        for _ in range(fsmn_layers):
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache

    def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        feats, feats_len = self.extract_feat(waveforms, is_final)
        segments = []
        if feats.size != 0:
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                inputs = [feats]
                inputs.extend(in_cache)
                scores, out_caches = self.infer(inputs)
                # Hand the updated FSMN caches back to the caller so the next
                # chunk continues from this state.
                param_dict['in_cache'] = out_caches
                waveforms = self.frontend.get_waveforms()
                segments = self.vad_scorer(scores, waveforms, is_final=is_final,
                                           max_end_sil=self.max_end_sil,
                                           online=True)
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                segments = []
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]
        # Fbank extraction is delegated to the streaming frontend, which
        # keeps its own state across chunks, so no padding is done here.
        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
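

# A minimal streaming sketch for the online VAD above. The model directory
# and wav path are hypothetical placeholders, and the 9600-sample chunk size
# (600 ms at 16 kHz) is an assumption rather than a model requirement. The
# essential pattern is threading the same param_dict through every call so
# the FSMN caches persist, and marking the last chunk with is_final.
def _example_online_vad():
    model_dir = '/path/to/fsmn_vad_model'                    # hypothetical
    vad = Fsmn_vad_online(model_dir)
    audio, _ = librosa.load('/path/to/test.wav', sr=16000)   # hypothetical
    chunk_size = 9600
    param_dict = {'in_cache': []}
    for offset in range(0, len(audio), chunk_size):
        chunk = audio[offset:offset + chunk_size]
        param_dict['is_final'] = offset + chunk_size >= len(audio)
        segments = vad(chunk, param_dict=param_dict)
        if segments:
            print(segments)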