|
|
@@ -158,13 +158,14 @@ class SANMVadEncoder(nn.Module):
|
|
|
def forward(self,
|
|
|
speech: torch.Tensor,
|
|
|
speech_lengths: torch.Tensor,
|
|
|
- vad_mask: torch.Tensor,
|
|
|
+ vad_masks: torch.Tensor,
|
|
|
sub_masks: torch.Tensor,
|
|
|
):
|
|
|
speech = speech * self._output_size ** 0.5
|
|
|
mask = self.make_pad_mask(speech_lengths)
|
|
|
+ vad_masks = self.prepare_mask(mask, vad_masks)
|
|
|
mask = self.prepare_mask(mask, sub_masks)
|
|
|
- vad_mask = self.prepare_mask(mask, vad_mask)
|
|
|
+
|
|
|
if self.embed is None:
|
|
|
xs_pad = speech
|
|
|
else:
|
|
|
@@ -176,7 +177,7 @@ class SANMVadEncoder(nn.Module):
|
|
|
# encoder_outs = self.model.encoders(xs_pad, mask)
|
|
|
for layer_idx, encoder_layer in enumerate(self.model.encoders):
|
|
|
if layer_idx == len(self.model.encoders) - 1:
|
|
|
- mask = vad_mask
|
|
|
+ mask = vad_masks
|
|
|
encoder_outs = encoder_layer(xs_pad, mask)
|
|
|
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
|
|
|
|
|
@@ -187,26 +188,26 @@ class SANMVadEncoder(nn.Module):
|
|
|
def get_output_size(self):
|
|
|
return self.model.encoders[0].size
|
|
|
|
|
|
- def get_dummy_inputs(self):
|
|
|
- feats = torch.randn(1, 100, self.feats_dim)
|
|
|
- return (feats)
|
|
|
-
|
|
|
- def get_input_names(self):
|
|
|
- return ['feats']
|
|
|
-
|
|
|
- def get_output_names(self):
|
|
|
- return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
|
|
|
-
|
|
|
- def get_dynamic_axes(self):
|
|
|
- return {
|
|
|
- 'feats': {
|
|
|
- 1: 'feats_length'
|
|
|
- },
|
|
|
- 'encoder_out': {
|
|
|
- 1: 'enc_out_length'
|
|
|
- },
|
|
|
- 'predictor_weight': {
|
|
|
- 1: 'pre_out_length'
|
|
|
- }
|
|
|
-
|
|
|
- }
|
|
|
+ # def get_dummy_inputs(self):
|
|
|
+ # feats = torch.randn(1, 100, self.feats_dim)
|
|
|
+ # return (feats)
|
|
|
+ #
|
|
|
+ # def get_input_names(self):
|
|
|
+ # return ['feats']
|
|
|
+ #
|
|
|
+ # def get_output_names(self):
|
|
|
+ # return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
|
|
|
+ #
|
|
|
+ # def get_dynamic_axes(self):
|
|
|
+ # return {
|
|
|
+ # 'feats': {
|
|
|
+ # 1: 'feats_length'
|
|
|
+ # },
|
|
|
+ # 'encoder_out': {
|
|
|
+ # 1: 'enc_out_length'
|
|
|
+ # },
|
|
|
+ # 'predictor_weight': {
|
|
|
+ # 1: 'pre_out_length'
|
|
|
+ # }
|
|
|
+ #
|
|
|
+ # }
|