|
|
@@ -31,10 +31,12 @@ class CifPredictor(nn.Module):
|
|
|
alphas = torch.sigmoid(output)
|
|
|
alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
|
|
|
if mask is not None:
|
|
|
- alphas = alphas * mask.transpose(-1, -2).float()
|
|
|
+ mask = mask.transpose(-1, -2).float()
|
|
|
+ alphas = alphas * mask
|
|
|
if mask_chunk_predictor is not None:
|
|
|
alphas = alphas * mask_chunk_predictor
|
|
|
alphas = alphas.squeeze(-1)
|
|
|
+ mask = mask.squeeze(-1)
|
|
|
if target_label_length is not None:
|
|
|
target_length = target_label_length
|
|
|
elif target_label is not None:
|