|
|
@@ -192,7 +192,7 @@ class WindowDetector(object):
|
|
|
|
|
|
|
|
|
class E2EVadModel(nn.Module):
|
|
|
- def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
|
|
|
+ def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
|
|
|
super(E2EVadModel, self).__init__()
|
|
|
self.vad_opts = VADXOptions(**vad_post_args)
|
|
|
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
|
|
|
@@ -227,7 +227,6 @@ class E2EVadModel(nn.Module):
|
|
|
self.data_buf = None
|
|
|
self.data_buf_all = None
|
|
|
self.waveform = None
|
|
|
- self.streaming = streaming
|
|
|
self.ResetDetection()
|
|
|
|
|
|
def AllResetDetection(self):
|
|
|
@@ -451,11 +450,7 @@ class E2EVadModel(nn.Module):
|
|
|
if not is_final_send:
|
|
|
self.DetectCommonFrames()
|
|
|
else:
|
|
|
- if self.streaming:
|
|
|
- self.DetectLastFrames()
|
|
|
- else:
|
|
|
- self.AllResetDetection()
|
|
|
- self.DetectAllFrames() # offline decode and is_final_send == True
|
|
|
+ self.DetectLastFrames()
|
|
|
segments = []
|
|
|
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
|
|
|
segment_batch = []
|
|
|
@@ -468,7 +463,8 @@ class E2EVadModel(nn.Module):
|
|
|
self.output_data_buf_offset += 1 # need update this parameter
|
|
|
if segment_batch:
|
|
|
segments.append(segment_batch)
|
|
|
-
|
|
|
+ if is_final_send:
|
|
|
+ self.AllResetDetection()
|
|
|
return segments
|
|
|
|
|
|
def DetectCommonFrames(self) -> int:
|
|
|
@@ -494,18 +490,6 @@ class E2EVadModel(nn.Module):
|
|
|
|
|
|
return 0
|
|
|
|
|
|
- def DetectAllFrames(self) -> int:
|
|
|
- if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
|
|
- return 0
|
|
|
- if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
|
|
|
- frame_state = FrameState.kFrameStateInvalid
|
|
|
- for t in range(0, self.frm_cnt):
|
|
|
- frame_state = self.GetFrameState(t)
|
|
|
- self.DetectOneFrame(frame_state, t, t == self.frm_cnt - 1)
|
|
|
- else:
|
|
|
- pass
|
|
|
- return 0
|
|
|
-
|
|
|
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
|
|
|
tmp_cur_frm_state = FrameState.kFrameStateInvalid
|
|
|
if cur_frm_state == FrameState.kFrameStateSpeech:
|