|
@@ -221,13 +221,14 @@ class CifPredictorV2(nn.Module):
|
|
|
|
|
|
|
|
if cache is not None and "chunk_size" in cache:
|
|
if cache is not None and "chunk_size" in cache:
|
|
|
alphas[:, :cache["chunk_size"][0]] = 0.0
|
|
alphas[:, :cache["chunk_size"][0]] = 0.0
|
|
|
- alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
|
|
|
|
|
|
+ if "is_final" in cache and not cache["is_final"]:
|
|
|
|
|
+ alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
|
|
|
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
|
|
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
|
|
|
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
|
|
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
|
|
|
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
|
|
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
|
|
|
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
|
|
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
|
|
|
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
|
|
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
|
|
|
- if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
|
|
|
|
|
|
|
+ if cache is not None and "is_final" in cache and cache["is_final"]:
|
|
|
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
|
|
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
|
|
|
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
|
|
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
|
|
|
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
|
|
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))
|