# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple

import librosa
import numpy as np

from .utils.utils import (ONNXRuntimeError,
                          OrtInferSession, get_logger,
                          read_yaml)
from .utils.frontend import WavFrontend, WavFrontendOnline
from .utils.e2e_vad import E2EVadModel

logging = get_logger()


class Fsmn_vad():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError:
                raise ImportError(
                    "Downloading models from modelscope requires the modelscope package. "
                    "You can install it with:\n"
                    "\npip3 install -U modelscope\n"
                    "For users in China, you can install it from a mirror:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name on modelscope or a local path "
                    "downloaded from modelscope, but is {}".format(model_dir))

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print("{} does not exist, exporting onnx".format(model_file))
            try:
                from funasr.export.export_model import ModelExport
            except ImportError:
                raise ImportError(
                    "Exporting onnx models requires the funasr package. "
                    "You can install it with:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you can install it from a mirror:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)

        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id,
                                         intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = max_end_sil if max_end_sil is not None \
            else config["vad_post_conf"]["max_end_silence_time"]
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        # Reuse the caller's cache when one is passed in; otherwise build a
        # fresh zero cache (avoids the mutable-default-argument pitfall).
        if in_cache:
            return in_cache
        in_cache = []
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        for i in range(fsmn_layers):
            # one zero-initialized FSMN memory block per layer
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache

    def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        is_final = kwargs.get('is_final', False)
        # Independent result lists per batch item ([[]] * n would alias them).
        segments = [[] for _ in range(self.batch_size)]
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            waveform = waveform_list[beg_idx:end_idx]
            feats, feats_len = self.extract_feat(waveform)
            waveform = np.array(waveform)
            param_dict = kwargs.get('param_dict', dict())
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                # Decode in chunks of at most 6000 frames, carrying the FSMN
                # cache across chunks. Note: the chunking below assumes
                # batch_size == 1.
                total_frames = int(feats_len.max())
                step = min(total_frames, 6000)
                for t_offset in range(0, total_frames, step):
                    if t_offset + step >= total_frames - 1:
                        step = total_frames - t_offset
                        is_final = True
                    else:
                        is_final = False
                    feats_package = feats[:, t_offset:t_offset + step, :]
                    # Map frame offsets back to samples (160-sample hop,
                    # 400-sample window).
                    waveform_package = waveform[:, t_offset * 160:
                                                min(waveform.shape[-1],
                                                    (t_offset + step - 1) * 160 + 400)]
                    inputs = [feats_package]
                    inputs.extend(in_cache)
                    scores, out_caches = self.infer(inputs)
                    in_cache = out_caches
                    segments_part = self.vad_scorer(scores, waveform_package,
                                                    is_final=is_final,
                                                    max_end_sil=self.max_end_sil,
                                                    online=False)
                    if segments_part:
                        for batch_num in range(0, self.batch_size):
                            segments[batch_num] += segments_part[batch_num]
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                segments = []
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveform_list: List[np.ndarray]
                     ) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)
        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
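

# Usage sketch (illustrative, not part of the original module): a minimal
# offline call, assuming the public FSMN VAD model id on modelscope and a
# local 16 kHz wav file; both names are placeholders:
#
#     vad = Fsmn_vad("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
#     segments = vad("sample_16k.wav")
#     # segments[0]: list of [beg_ms, end_ms] speech intervals for the input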


class Fsmn_vad_online():
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 max_end_sil: int = None,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except ImportError:
                raise ImportError(
                    "Downloading models from modelscope requires the modelscope package. "
                    "You can install it with:\n"
                    "\npip3 install -U modelscope\n"
                    "For users in China, you can install it from a mirror:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception:
                raise ValueError(
                    "model_dir must be a model name on modelscope or a local path "
                    "downloaded from modelscope, but is {}".format(model_dir))

        model_file = os.path.join(model_dir, 'model.onnx')
        if quantize:
            model_file = os.path.join(model_dir, 'model_quant.onnx')
        if not os.path.exists(model_file):
            print("{} does not exist, exporting onnx".format(model_file))
            try:
                from funasr.export.export_model import ModelExport
            except ImportError:
                raise ImportError(
                    "Exporting onnx models requires the funasr package. "
                    "You can install it with:\n"
                    "\npip3 install -U funasr\n"
                    "For users in China, you can install it from a mirror:\n"
                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple")
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)

        config_file = os.path.join(model_dir, 'vad.yaml')
        cmvn_file = os.path.join(model_dir, 'vad.mvn')
        config = read_yaml(config_file)

        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(model_file, device_id,
                                         intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
        self.max_end_sil = max_end_sil if max_end_sil is not None \
            else config["vad_post_conf"]["max_end_silence_time"]
        self.encoder_conf = config["encoder_conf"]

    def prepare_cache(self, in_cache: list = None):
        # Reuse the caller's cache when one is passed in; otherwise build a
        # fresh zero cache (avoids the mutable-default-argument pitfall).
        if in_cache:
            return in_cache
        in_cache = []
        fsmn_layers = self.encoder_conf["fsmn_layers"]
        proj_dim = self.encoder_conf["proj_dim"]
        lorder = self.encoder_conf["lorder"]
        for i in range(fsmn_layers):
            # one zero-initialized FSMN memory block per layer
            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
            in_cache.append(cache)
        return in_cache

    def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        feats, feats_len = self.extract_feat(waveforms, is_final)
        segments = []
        if feats.size != 0:
            in_cache = param_dict.get('in_cache', list())
            in_cache = self.prepare_cache(in_cache)
            try:
                inputs = [feats]
                inputs.extend(in_cache)
                scores, out_caches = self.infer(inputs)
                # Hand the updated FSMN cache back to the caller so the next
                # chunk continues from this state.
                param_dict['in_cache'] = out_caches
                waveforms = self.frontend.get_waveforms()
                segments = self.vad_scorer(scores, waveforms, is_final=is_final,
                                           max_end_sil=self.max_end_sil,
                                           online=True)
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                segments = []
        return segments

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]
        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, 'constant', constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer(feats)
        scores, out_caches = outputs[0], outputs[1:]
        return scores, out_caches
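

if __name__ == "__main__":
    # Streaming usage sketch (illustrative): the model id below is the public
    # FSMN VAD model on modelscope, and "sample_16k.wav" stands in for any
    # local 16 kHz mono wav file.
    model = Fsmn_vad_online("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch")
    speech, _ = librosa.load("sample_16k.wav", sr=16000)

    # Feed 100 ms chunks (1600 samples at 16 kHz). param_dict carries the
    # FSMN cache between calls; is_final flushes the remaining frames.
    step = 1600
    param_dict = {'in_cache': []}
    for sample_offset in range(0, speech.shape[0], step):
        param_dict['is_final'] = sample_offset + step >= speech.shape[0]
        segments = model(audio_in=speech[sample_offset:sample_offset + step],
                         param_dict=param_dict)
        if segments:
            print(segments)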