# vad_bin.py
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Tuple, Union

import librosa
import numpy as np

from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
from .utils.frontend import WavFrontend, WavFrontendOnline
from .utils.e2e_vad import E2EVadModel

logging = get_logger()


class Fsmn_vad():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 ):
        if model_dir is None or not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        # Fall back to the configured maximum end-silence when the caller
        # does not override it.
        self.max_end_sil = (max_end_sil if max_end_sil is not None
                            else config["vad_post_conf"]["max_end_silence_time"])
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        # Use None instead of a mutable default: a shared `[]` default would
        # accumulate caches across calls.
        if in_cache is None:
            in_cache = []
        if len(in_cache) > 0:
            return in_cache
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        # One zero-initialized FSMN memory block per layer.
        for i in range(fsmn_layers):
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache
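
    # Cache layout sketch: each list entry is one layer's FSMN memory with
    # shape (1, proj_dim, lorder - 1, 1). With a typical FunASR VAD config
    # (fsmn_layers=4, proj_dim=128, lorder=20 -- assumed values, check your
    # vad.yaml to confirm), prepare_cache() yields four float32 arrays of
    # shape (1, 128, 19, 1).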

    def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        is_final = kwargs.get('is_final', False)
        # A list comprehension, not `[[]] * batch_size`: the latter would
        # alias one shared list across every batch slot.
        segments = [[] for _ in range(self.batch_size)]
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            waveform = waveform_list[beg_idx:end_idx]
            feats, feats_len = self.extract_feat(waveform)
            waveform = np.array(waveform)
            param_dict = kwargs.get('param_dict', dict())
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                # Score the features in chunks of at most 6000 frames,
                # threading the FSMN cache from one chunk to the next.
                n_frames = int(feats_len.max())
                step = min(n_frames, 6000)
                for t_offset in range(0, n_frames, step):
                    if t_offset + step >= n_frames - 1:
                        step = n_frames - t_offset
                        is_final = True
                    else:
                        is_final = False
                    feats_package = feats[:, t_offset:t_offset + step, :]
                    # With a 10 ms frame shift (160 samples) and 25 ms frame
                    # length (400 samples) at 16 kHz, frame t starts at sample
                    # t * 160.
                    waveform_package = waveform[:, t_offset * 160:min(
                        waveform.shape[-1], (t_offset + step - 1) * 160 + 400)]
                    inputs = [feats_package]
                    inputs.extend(in_cache)
                    scores, out_caches = self.infer(inputs)
                    in_cache = out_caches
                    segments_part = self.vad_scorer(
                        scores, waveform_package, is_final=is_final,
                        max_end_sil=self.max_end_sil, online=False)
                    if segments_part:
                        for batch_num in range(0, self.batch_size):
                            segments[batch_num] += segments_part[batch_num]
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                # Keep the return type consistent with the declared List.
                segments = []
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        # Pad every utterance to the longest in the batch so they stack.
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer(feats)
        # The first output holds the frame-level VAD scores; the remaining
        # outputs are the updated FSMN caches for the next call.
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
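

# Usage sketch (offline): the model directory below is hypothetical and must
# contain model.onnx, vad.yaml and vad.mvn, as loaded in __init__. Segments
# come back per input wav in the format produced by E2EVadModel, i.e. lists
# of [start_ms, end_ms] pairs.
#
#     vad = Fsmn_vad("/path/to/fsmn-vad-model-dir")
#     segments = vad("speech.wav")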


class Fsmn_vad_online():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 ):
        if model_dir is None or not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        # The online variant uses the streaming frontend, which buffers
        # samples between calls.
        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = (max_end_sil if max_end_sil is not None
                            else config["vad_post_conf"]["max_end_silence_time"])
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        # Same mutable-default fix as in Fsmn_vad.prepare_cache.
        if in_cache is None:
            in_cache = []
        if len(in_cache) > 0:
            return in_cache
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        for i in range(fsmn_layers):
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache

    def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        feats, feats_len = self.extract_feat(waveforms, is_final)
        segments = []
        if feats.size != 0:
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                inputs = [feats]
                inputs.extend(in_cache)
                scores, out_caches = self.infer(inputs)
                # Hand the updated cache back through param_dict so the caller
                # can pass it in with the next chunk.
                param_dict['in_cache'] = out_caches
                waveforms = self.frontend.get_waveforms()
                segments = self.vad_scorer(
                    scores, waveforms, is_final=is_final,
                    max_end_sil=self.max_end_sil, online=True)
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                segments = []
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]
        # The streaming frontend buffers samples internally and may return an
        # empty array until it has accumulated enough for a full frame.
        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
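

# Streaming usage sketch (hedged: the chunk size, 16 kHz sample rate, and
# soundfile reader are assumptions, not requirements read from vad.yaml).
# The caller owns param_dict: 'in_cache' carries FSMN state between chunks
# and 'is_final' flushes the frontend buffer on the last chunk.
#
#     import soundfile
#
#     vad = Fsmn_vad_online("/path/to/fsmn-vad-model-dir")
#     speech, fs = soundfile.read("speech.wav")
#     chunk = 9600  # 600 ms at 16 kHz
#     param_dict = {"in_cache": []}
#     for beg in range(0, len(speech), chunk):
#         param_dict["is_final"] = beg + chunk >= len(speech)
#         segments = vad(audio_in=speech[beg:beg + chunk], param_dict=param_dict)
#         if segments:
#             print(segments)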