|
|
@@ -19,7 +19,7 @@ def cif_wo_hidden(alphas, threshold):
|
|
|
list_fires.append(integrate)
|
|
|
fire_place = integrate >= threshold
|
|
|
integrate = torch.where(fire_place,
|
|
|
- integrate - torch.ones([batch_size], device=alphas.device),
|
|
|
+ integrate - torch.ones([batch_size], device=alphas.device)*threshold,
|
|
|
integrate)
|
|
|
fires = torch.stack(list_fires, 1)
|
|
|
return fires
|