|
|
@@ -395,8 +395,10 @@ class WavFrontendOnline(AbsFrontend):
|
|
|
return feats_pad, feats_lens, lfr_splice_frame_idxs
|
|
|
|
|
|
def forward(
|
|
|
- self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
|
|
|
+ self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False, reset: bool = True
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
+ if reset:
|
|
|
+ self.cache_reset()
|
|
|
batch_size = input.shape[0]
|
|
|
assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
|
|
|
waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths) # input shape: B T D
|
|
|
@@ -500,4 +502,4 @@ class WavFrontendMel23(AbsFrontend):
|
|
|
feats_pad = pad_sequence(feats,
|
|
|
batch_first=True,
|
|
|
padding_value=0.0)
|
|
|
- return feats_pad, feats_lens
|
|
|
+ return feats_pad, feats_lens
|