|
|
@@ -186,11 +186,12 @@ class CT_Transformer_VadRealtime(CT_Transformer):
|
|
|
mini_sentence = cache_sent + mini_sentence
|
|
|
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
|
|
|
text_length = len(mini_sentence_id)
|
|
|
+ vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
|
|
|
data = {
|
|
|
"input": mini_sentence_id[None,:],
|
|
|
"text_lengths": np.array([text_length], dtype='int32'),
|
|
|
- "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
|
|
|
- "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
|
|
|
+ "vad_mask": vad_mask,
|
|
|
+ "sub_masks": vad_mask
|
|
|
}
|
|
|
try:
|
|
|
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
|