#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os
import json
import time
import math
import torch
from torch import nn
from enum import Enum
from dataclasses import dataclass
from funasr.register import tables
from typing import List, Tuple, Dict, Any, Optional
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


class VadStateMachine(Enum):
    kVadInStateStartPointNotDetected = 1
    kVadInStateInSpeechSegment = 2
    kVadInStateEndPointDetected = 3


class FrameState(Enum):
    kFrameStateInvalid = -1
    kFrameStateSpeech = 1
    kFrameStateSil = 0


# final voice/unvoice state per frame
class AudioChangeState(Enum):
    kChangeStateSpeech2Speech = 0
    kChangeStateSpeech2Sil = 1
    kChangeStateSil2Sil = 2
    kChangeStateSil2Speech = 3
    kChangeStateNoBegin = 4
    kChangeStateInvalid = 5


class VadDetectMode(Enum):
    kVadSingleUtteranceDetectMode = 0
    kVadMutipleUtteranceDetectMode = 1
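

# How the enums above fit together (summary of the detector logic below):
#   kVadInStateStartPointNotDetected -> kVadInStateInSpeechSegment
#       on a confirmed silence-to-speech change;
#   kVadInStateInSpeechSegment -> kVadInStateEndPointDetected
#       on sustained trailing silence, a max-segment-length timeout,
#       or the final frame of the stream.
# In multiple-utterance mode, detection state is reset after each end point
# so the next segment can be found.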
class VADXOptions:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
        snr_mode: int = 0,
        max_end_silence_time: int = 800,
        max_start_silence_time: int = 3000,
        do_start_point_detection: bool = True,
        do_end_point_detection: bool = True,
        window_size_ms: int = 200,
        sil_to_speech_time_thres: int = 150,
        speech_to_sil_time_thres: int = 150,
        speech_2_noise_ratio: float = 1.0,
        do_extend: int = 1,
        lookback_time_start_point: int = 200,
        lookahead_time_end_point: int = 100,
        max_single_segment_time: int = 60000,
        nn_eval_block_size: int = 8,
        dcd_block_size: int = 4,
        snr_thres: float = -100.0,
        noise_frame_num_used_for_snr: int = 100,
        decibel_thres: float = -100.0,
        speech_noise_thres: float = 0.6,
        fe_prior_thres: float = 1e-4,
        silence_pdf_num: int = 1,
        sil_pdf_ids: List[int] = [0],
        speech_noise_thresh_low: float = -0.1,
        speech_noise_thresh_high: float = 0.3,
        output_frame_probs: bool = False,
        frame_in_ms: int = 10,
        frame_length_ms: int = 25,
        **kwargs,
    ):
        self.sample_rate = sample_rate
        self.detect_mode = detect_mode
        self.snr_mode = snr_mode
        self.max_end_silence_time = max_end_silence_time
        self.max_start_silence_time = max_start_silence_time
        self.do_start_point_detection = do_start_point_detection
        self.do_end_point_detection = do_end_point_detection
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time_thres = sil_to_speech_time_thres
        self.speech_to_sil_time_thres = speech_to_sil_time_thres
        self.speech_2_noise_ratio = speech_2_noise_ratio
        self.do_extend = do_extend
        self.lookback_time_start_point = lookback_time_start_point
        self.lookahead_time_end_point = lookahead_time_end_point
        self.max_single_segment_time = max_single_segment_time
        self.nn_eval_block_size = nn_eval_block_size
        self.dcd_block_size = dcd_block_size
        self.snr_thres = snr_thres
        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
        self.decibel_thres = decibel_thres
        self.speech_noise_thres = speech_noise_thres
        self.fe_prior_thres = fe_prior_thres
        self.silence_pdf_num = silence_pdf_num
        self.sil_pdf_ids = sil_pdf_ids
        self.speech_noise_thresh_low = speech_noise_thresh_low
        self.speech_noise_thresh_high = speech_noise_thresh_high
        self.output_frame_probs = output_frame_probs
        self.frame_in_ms = frame_in_ms
        self.frame_length_ms = frame_length_ms
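

# All *_time / *_ms options above are in milliseconds. A minimal usage sketch
# (values here are illustrative, not recommended settings): a smaller
# max_end_silence_time makes the detector close segments on shorter pauses.
#
#   opts = VADXOptions(max_end_silence_time=500, speech_noise_thres=0.8)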


class E2EVadSpeechBufWithDoa(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0

    def Reset(self):
        self.start_ms = 0
        self.end_ms = 0
        self.buffer = []
        self.contain_seg_start_point = False
        self.contain_seg_end_point = False
        self.doa = 0


class E2EVadFrameProb(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
        self.score = 0.0
        self.frame_id = 0
        self.frm_state = 0


class WindowDetector(object):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self, window_size_ms: int,
                 sil_to_speech_time: int,
                 speech_to_sil_time: int,
                 frame_size_ms: int):
        self.window_size_ms = window_size_ms
        self.sil_to_speech_time = sil_to_speech_time
        self.speech_to_sil_time = speech_to_sil_time
        self.frame_size_ms = frame_size_ms
        self.win_size_frame = int(window_size_ms / frame_size_ms)
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame  # initialize the sliding window
        self.cur_win_pos = 0
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def Reset(self) -> None:
        self.cur_win_pos = 0
        self.win_sum = 0
        self.win_state = [0] * self.win_size_frame
        self.pre_frame_state = FrameState.kFrameStateSil
        self.cur_frame_state = FrameState.kFrameStateSil
        self.voice_last_frame_count = 0
        self.noise_last_frame_count = 0
        self.hydre_frame_count = 0

    def GetWinSize(self) -> int:
        return int(self.win_size_frame)

    def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
        cur_frame_state = FrameState.kFrameStateSil
        if frameState == FrameState.kFrameStateSpeech:
            cur_frame_state = 1
        elif frameState == FrameState.kFrameStateSil:
            cur_frame_state = 0
        else:
            return AudioChangeState.kChangeStateInvalid
        # ring buffer: replace the oldest frame decision with the newest one
        self.win_sum -= self.win_state[self.cur_win_pos]
        self.win_sum += cur_frame_state
        self.win_state[self.cur_win_pos] = cur_frame_state
        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSpeech
            return AudioChangeState.kChangeStateSil2Speech
        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
            self.pre_frame_state = FrameState.kFrameStateSil
            return AudioChangeState.kChangeStateSpeech2Sil
        if self.pre_frame_state == FrameState.kFrameStateSil:
            return AudioChangeState.kChangeStateSil2Sil
        if self.pre_frame_state == FrameState.kFrameStateSpeech:
            return AudioChangeState.kChangeStateSpeech2Speech
        return AudioChangeState.kChangeStateInvalid

    def FrameSizeMs(self) -> int:
        return int(self.frame_size_ms)
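

# Illustrative sketch (not part of the original logic): the window detector
# votes over the last window_size_ms of per-frame decisions, so isolated
# misclassified frames do not flip the state. With the values below, a
# silence-to-speech change fires once >= 15 of the last 20 frames are speech.
#
#   wd = WindowDetector(window_size_ms=200, sil_to_speech_time=150,
#                       speech_to_sil_time=150, frame_size_ms=10)
#   for t in range(20):
#       change = wd.DetectOneFrame(FrameState.kFrameStateSpeech, t)
#   # change == AudioChangeState.kChangeStateSil2Speech after frame 14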


class Stats(object):
    def __init__(self,
                 sil_pdf_ids,
                 max_end_sil_frame_cnt_thresh,
                 speech_noise_thres,
                 ):
        self.data_buf_start_frame = 0
        self.frm_cnt = 0
        self.latest_confirmed_speech_frame = 0
        self.lastest_confirmed_silence_frame = -1
        self.continous_silence_frame_count = 0
        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        self.confirmed_start_frame = -1
        self.confirmed_end_frame = -1
        self.number_end_time_detected = 0
        self.sil_frame = 0
        self.sil_pdf_ids = sil_pdf_ids
        self.noise_average_decibel = -100.0
        self.pre_end_silence_detected = False
        self.next_seg = True
        self.output_data_buf = []
        self.output_data_buf_offset = 0
        self.frame_probs = []
        self.max_end_sil_frame_cnt_thresh = max_end_sil_frame_cnt_thresh
        self.speech_noise_thres = speech_noise_thres
        self.scores = None
        self.max_time_out = False
        self.decibel = []
        self.data_buf = None
        self.data_buf_all = None
        self.waveform = None
        self.last_drop_frames = 0


@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
    https://arxiv.org/abs/1803.05030
    """

    def __init__(self,
                 encoder: str = None,
                 encoder_conf: Optional[Dict] = None,
                 vad_post_args: Dict[str, Any] = None,
                 **kwargs,
                 ):
        super().__init__()
        self.vad_opts = VADXOptions(**kwargs)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(**encoder_conf)
        self.encoder = encoder

    def ResetDetection(self, cache: dict = {}):
        cache["stats"].continous_silence_frame_count = 0
        cache["stats"].latest_confirmed_speech_frame = 0
        cache["stats"].lastest_confirmed_silence_frame = -1
        cache["stats"].confirmed_start_frame = -1
        cache["stats"].confirmed_end_frame = -1
        cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
        cache["windows_detector"].Reset()
        cache["stats"].sil_frame = 0
        cache["stats"].frame_probs = []

        if cache["stats"].output_data_buf:
            assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
            cache["stats"].last_drop_frames = drop_frames
            cache["stats"].data_buf_all = cache["stats"].data_buf_all[
                real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]

    def ComputeDecibel(self, cache: dict = {}) -> None:
        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if cache["stats"].data_buf_all is None:
            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # data_buf shares storage with waveform[0]
            cache["stats"].data_buf = cache["stats"].data_buf_all
        else:
            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
        for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
            cache["stats"].decibel.append(
                10 * math.log10(
                    (cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + 0.000001))
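
    # With the 16 kHz defaults, each decibel entry covers a 25 ms frame
    # (frame_length_ms -> 400 samples) hopped every 10 ms (frame_in_ms -> 160
    # samples), computed as 10 * log10(sum(x[n]^2) + 1e-6): frame energy in dB,
    # with a small floor so log10 stays defined on pure silence.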

    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
        scores = self.encoder(feats, cache=cache["encoder"]).to("cpu")  # returns B * T * D
        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
        self.vad_opts.nn_eval_block_size = scores.shape[1]
        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
        if cache["stats"].scores is None:
            cache["stats"].scores = scores  # the first calculation
        else:
            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)

    def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None:  # need check again
        while cache["stats"].data_buf_start_frame < frame_idx:
            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                cache["stats"].data_buf_start_frame += 1
                cache["stats"].data_buf = cache["stats"].data_buf_all[
                    (cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames)
                    * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]

    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
        self.PopDataBufTillFrame(start_frm, cache=cache)
        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
        if last_frm_is_end_point:
            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
                                      - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
            expected_sample_number += int(extra_sample)
        if end_point_is_sent_end:
            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
        if len(cache["stats"].data_buf) < expected_sample_number:
            print("error in calling pop data_buf\n")

        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
            cache["stats"].output_data_buf[-1].Reset()
            cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
            cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
            cache["stats"].output_data_buf[-1].doa = 0
        cur_seg = cache["stats"].output_data_buf[-1]
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print("warning\n")
        out_pos = len(cur_seg.buffer)  # cur_seg.buffer is left untouched for now
        data_to_pop = 0
        if end_point_is_sent_end:
            data_to_pop = expected_sample_number
        else:
            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
        if data_to_pop > len(cache["stats"].data_buf):
            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
            data_to_pop = len(cache["stats"].data_buf)
            expected_sample_number = len(cache["stats"].data_buf)

        cur_seg.doa = 0
        for sample_cpy_out in range(0, data_to_pop):
            # cur_seg.buffer[out_pos ++] = data_buf_.back();
            out_pos += 1
        for sample_cpy_out in range(data_to_pop, expected_sample_number):
            # cur_seg.buffer[out_pos++] = data_buf_.back()
            out_pos += 1
        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
            print("Something wrong with the VAD algorithm\n")
        cache["stats"].data_buf_start_frame += frm_cnt
        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
        if first_frm_is_start_point:
            cur_seg.contain_seg_start_point = True
        if last_frm_is_end_point:
            cur_seg.contain_seg_end_point = True

    def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
        cache["stats"].lastest_confirmed_silence_frame = valid_frame
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataBufTillFrame(valid_frame, cache=cache)
        # silence_detected_callback_
        # pass

    def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
        cache["stats"].latest_confirmed_speech_frame = valid_frame
        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)

    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
        if self.vad_opts.do_start_point_detection:
            pass
        if cache["stats"].confirmed_start_frame != -1:
            print("not reset vad properly\n")
        else:
            cache["stats"].confirmed_start_frame = start_frame

        if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
            self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)

    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
            self.OnVoiceDetected(t, cache=cache)
        if self.vad_opts.do_end_point_detection:
            pass
        if cache["stats"].confirmed_end_frame != -1:
            print("not reset vad properly\n")
        else:
            cache["stats"].confirmed_end_frame = end_frame
        if not fake_result:
            cache["stats"].sil_frame = 0
            self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
        cache["stats"].number_end_time_detected += 1

    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
        if is_final_frame:
            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected

    def GetLatency(self, cache: dict = {}) -> int:
        return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)

    def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
        vad_latency = cache["windows_detector"].GetWinSize()
        if self.vad_opts.do_extend:
            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
        return vad_latency

    def GetFrameState(self, t: int, cache: dict = {}):
        frame_state = FrameState.kFrameStateInvalid
        cur_decibel = cache["stats"].decibel[t]
        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
        # for each frame, calc log posterior probability of each state
        if cur_decibel < self.vad_opts.decibel_thres:
            frame_state = FrameState.kFrameStateSil
            self.DetectOneFrame(frame_state, t, False, cache=cache)
            return frame_state

        sum_score = 0.0
        noise_prob = 0.0
        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
        if len(cache["stats"].sil_pdf_ids) > 0:
            assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported
            sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
            sum_score = sum(sil_pdf_scores)
            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
            total_score = 1.0
            sum_score = total_score - sum_score
        speech_prob = math.log(sum_score)
        if self.vad_opts.output_frame_probs:
            frame_prob = E2EVadFrameProb()
            frame_prob.noise_prob = noise_prob
            frame_prob.speech_prob = speech_prob
            frame_prob.score = sum_score
            frame_prob.frame_id = t
            cache["stats"].frame_probs.append(frame_prob)
        if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
                frame_state = FrameState.kFrameStateSpeech
            else:
                frame_state = FrameState.kFrameStateSil
        else:
            frame_state = FrameState.kFrameStateSil
            if cache["stats"].noise_average_decibel < -99.9:
                cache["stats"].noise_average_decibel = cur_decibel
            else:
                cache["stats"].noise_average_decibel = (
                    cur_decibel + cache["stats"].noise_average_decibel
                    * (self.vad_opts.noise_frame_num_used_for_snr - 1)
                ) / self.vad_opts.noise_frame_num_used_for_snr

        return frame_state
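
    # Decision rule above, in brief: a frame is speech when
    #   exp(speech_prob) >= exp(noise_prob) + speech_noise_thres
    # and both the SNR and absolute-decibel gates pass; otherwise it is silence.
    # On silence frames the noise floor is tracked as a running mean over the
    # last noise_frame_num_used_for_snr (N) frames:
    #   noise_avg <- (cur_db + noise_avg * (N - 1)) / N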

    def forward(self, feats: torch.Tensor, waveform: torch.Tensor, cache: dict = {},
                is_final: bool = False
                ):
        # if len(cache) == 0:
        #     self.AllResetDetection()
        # self.waveform = waveform  # compute decibel for each frame
        cache["stats"].waveform = waveform
        self.ComputeDecibel(cache=cache)
        self.ComputeScores(feats, cache=cache)
        if not is_final:
            self.DetectCommonFrames(cache=cache)
        else:
            self.DetectLastFrames(cache=cache)
        segments = []
        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
            segment_batch = []
            if len(cache["stats"].output_data_buf) > 0:
                for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
                    if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point
                                         or not cache["stats"].output_data_buf[i].contain_seg_end_point):
                        continue
                    segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
                    segment_batch.append(segment)
                    cache["stats"].output_data_buf_offset += 1  # need update this parameter
            if segment_batch:
                segments.append(segment_batch)
        # if is_final:
        #     # reset class variables and clear the dict for the next query
        #     self.AllResetDetection()
        return segments

    def init_cache(self, cache: dict = {}, **kwargs):
        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}
        windows_detector = WindowDetector(self.vad_opts.window_size_ms,
                                          self.vad_opts.sil_to_speech_time_thres,
                                          self.vad_opts.speech_to_sil_time_thres,
                                          self.vad_opts.frame_in_ms)
        windows_detector.Reset()
        stats = Stats(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
                      max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time
                      - self.vad_opts.speech_to_sil_time_thres,
                      speech_noise_thres=self.vad_opts.speech_noise_thres,
                      )
        cache["windows_detector"] = windows_detector
        cache["stats"] = stats
        return cache
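
    # Cache layout created above (one dict per stream):
    #   cache["frontend"]         - fbank frontend state (incl. leftover waveforms)
    #   cache["prev_samples"]     - raw samples not yet covered by a full chunk
    #   cache["encoder"]          - FSMN encoder memory states
    #   cache["windows_detector"] - sliding-window state smoother
    #   cache["stats"]            - per-stream counters, buffers and scores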

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  cache: dict = {},
                  **kwargs,
                  ):
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

        meta_data = {}
        chunk_size = kwargs.get("chunk_size", 60000)  # chunk size in ms
        chunk_stride_samples = int(chunk_size * frontend.fs / 1000)

        time1 = time.perf_counter()
        cfg = {"is_final": kwargs.get("is_final", False)}
        audio_sample_list = load_audio_text_image_video(data_in,
                                                        fs=frontend.fs,
                                                        audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"),
                                                        tokenizer=tokenizer,
                                                        cache=cfg,
                                                        )
        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be 1"

        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))

        n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
        m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
        segments = []
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]

            # extract fbank feats
            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend, cache=cache["frontend"],
                                                   is_final=kwargs["is_final"])
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

            speech = speech.to(device=kwargs["device"])
            speech_lengths = speech_lengths.to(device=kwargs["device"])

            batch = {
                "feats": speech,
                "waveform": cache["frontend"]["waveforms"],
                "is_final": kwargs["is_final"],
                "cache": cache,
            }
            segments_i = self.forward(**batch)
            if len(segments_i) > 0:
                segments.extend(*segments_i)

        # keep the m unconsumed tail samples for the next call (empty when m == 0)
        cache["prev_samples"] = audio_sample[len(audio_sample) - m:]
        if _is_final:
            self.init_cache(cache)

        ibest_writer = None
        if kwargs.get("output_dir") is not None:
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]

        results = []
        result_i = {"key": key[0], "value": segments}
        if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
            result_i = json.dumps(result_i)
        results.append(result_i)

        if ibest_writer is not None:
            ibest_writer["text"][key[0]] = segments

        return results, meta_data
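
    # Illustrative streaming usage (a sketch; building the model and frontend,
    # e.g. via funasr.AutoModel, is assumed and not shown here):
    #
    #   cache = {}
    #   for i, chunk in enumerate(chunks):  # raw waveform chunks
    #       results, _ = model.inference(chunk, key=["utt1"], frontend=frontend,
    #                                    cache=cache, device="cpu",
    #                                    is_final=(i == len(chunks) - 1))
    #       # results[0]["value"] holds completed [start_ms, end_ms] segments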

    def DetectCommonFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                             cache=cache)
            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
        return 0

    def DetectLastFrames(self, cache: dict = {}) -> int:
        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
            return 0
        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
            frame_state = FrameState.kFrameStateInvalid
            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
                                             cache=cache)
            if i != 0:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
            else:
                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
        return 0

    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
                       cache: dict = {}) -> None:
        tmp_cur_frm_state = FrameState.kFrameStateInvalid
        if cur_frm_state == FrameState.kFrameStateSpeech:
            # always true with the current constant; kept from the original logic
            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
                tmp_cur_frm_state = FrameState.kFrameStateSpeech
            else:
                tmp_cur_frm_state = FrameState.kFrameStateSil
        elif cur_frm_state == FrameState.kFrameStateSil:
            tmp_cur_frm_state = FrameState.kFrameStateSil
        state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
        frm_shift_in_ms = self.vad_opts.frame_in_ms
        if AudioChangeState.kChangeStateSil2Speech == state_change:
            silence_frame_count = cache["stats"].continous_silence_frame_count
            cache["stats"].continous_silence_frame_count = 0
            cache["stats"].pre_end_silence_detected = False
            start_frame = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                start_frame = max(cache["stats"].data_buf_start_frame,
                                  cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
                self.OnVoiceStart(start_frame, cache=cache)
                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                for t in range(start_frame + 1, cur_frm_idx + 1):
                    self.OnVoiceDetected(t, cache=cache)
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
                    self.OnVoiceDetected(t, cache=cache)
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
            cache["stats"].continous_silence_frame_count = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                pass
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
            cache["stats"].continous_silence_frame_count = 0
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    cache["stats"].max_time_out = True
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif not is_final_frame:
                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
            cache["stats"].continous_silence_frame_count += 1
            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                # silence timeout, return zero length decision
                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value)
                        and (cache["stats"].continous_silence_frame_count * frm_shift_in_ms
                             > self.vad_opts.max_start_silence_time)) \
                        or (is_final_frame and cache["stats"].number_end_time_detected == 0):
                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
                        self.OnSilenceDetected(t, cache=cache)
                    self.OnVoiceStart(0, True, cache=cache)
                    self.OnVoiceEnd(0, True, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                else:
                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
                if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= \
                        cache["stats"].max_end_sil_frame_cnt_thresh:
                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                    if self.vad_opts.do_extend:
                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
                        lookback_frame -= 1
                        lookback_frame = max(0, lookback_frame)
                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                elif self.vad_opts.do_extend and not is_final_frame:
                    if cache["stats"].continous_silence_frame_count <= int(
                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
                else:
                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
            else:
                pass

        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
            self.ResetDetection(cache=cache)
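
    # Note on emitted segments (see forward() above): mid-stream, only
    # segments whose start and end points are both confirmed are returned;
    # on the final chunk (is_final=True) any remaining open segment is
    # flushed as well.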