|
|
@@ -63,11 +63,11 @@ class VadRealtimeTransformer(nn.Module):
|
|
|
text_lengths = torch.tensor([length], dtype=torch.int32)
|
|
|
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
|
|
|
sub_masks = torch.ones(length, length, dtype=torch.float32)
|
|
|
- sub_masks = torch.tril(sub_masks)
|
|
|
- return (text_indexes, text_lengths, vad_mask, sub_masks)
|
|
|
+ sub_masks = torch.tril(sub_masks).type(torch.float32)
|
|
|
+ return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
|
|
|
|
|
|
def get_input_names(self):
|
|
|
- return ['input', 'text_lengths', 'vad_mask']
|
|
|
+ return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
|
|
|
|
|
|
def get_output_names(self):
|
|
|
return ['logits']
|
|
|
@@ -81,6 +81,10 @@ class VadRealtimeTransformer(nn.Module):
|
|
|
2: 'feats_length1',
|
|
|
3: 'feats_length2'
|
|
|
},
|
|
|
+ 'sub_masks': {
|
|
|
+ 2: 'feats_length1',
|
|
|
+ 3: 'feats_length2'
|
|
|
+ },
|
|
|
'logits': {
|
|
|
1: 'logits_length'
|
|
|
},
|