|
|
@@ -68,7 +68,8 @@ class CifPredictor(nn.Module):
|
|
|
mask_2 = torch.cat([ones_t, mask], dim=1)
|
|
|
mask = mask_2 - mask_1
|
|
|
tail_threshold = mask * tail_threshold
|
|
|
- alphas = torch.cat([alphas, tail_threshold], dim=1)
|
|
|
+ alphas = torch.cat([alphas, zeros_t], dim=1)
|
|
|
+ alphas = torch.add(alphas, tail_threshold)
|
|
|
else:
|
|
|
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
|
|
|
tail_threshold = torch.reshape(tail_threshold, (1, 1))
|
|
|
@@ -597,7 +598,8 @@ class CifPredictorV3(nn.Module):
|
|
|
mask_2 = torch.cat([ones_t, mask], dim=1)
|
|
|
mask = mask_2 - mask_1
|
|
|
tail_threshold = mask * tail_threshold
|
|
|
- alphas = torch.cat([alphas, tail_threshold], dim=1)
|
|
|
+ alphas = torch.cat([alphas, zeros_t], dim=1)
|
|
|
+ alphas = torch.add(alphas, tail_threshold)
|
|
|
else:
|
|
|
tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
|
|
|
tail_threshold = torch.reshape(tail_threshold, (1, 1))
|