@@ -471,15 +471,21 @@ class MultiHeadedAttentionSANM(nn.Module):
         """
         q_h, k_h, v_h, v = self.forward_qkv(x)
-        if chunk_size is not None and look_back > 0:
+        if chunk_size is not None and look_back > 0 or look_back == -1:
             if cache is not None:
+                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
                 k_h = torch.cat((cache["k"], k_h), dim=2)
                 v_h = torch.cat((cache["v"], v_h), dim=2)
-                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
-                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+
+                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+                if look_back != -1:
+                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
             else:
-                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
-                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+                             "v": v_h[:, :, :-(chunk_size[2]), :]}
                 cache = cache_tmp
         fsmn_memory = self.forward_fsmn(v, None)
         q_h = q_h * self.d_k ** (-0.5)
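To make the hunk easier to review, below is a small, self-contained sketch of the cache update this change introduces in the `cache is not None` branch. It is an illustration only, not the module's full forward pass: the shapes `(batch, heads, time, d_k)`, the concrete values of `chunk_size` and `look_back`, the reading that `chunk_size[1]` is the chunk length and `chunk_size[2]` is the number of trailing look-ahead frames, and the helper `update_cache` are all assumptions inferred from the slicing in the diff.

```python
import torch

# Hypothetical shapes/values for illustration only; the real module derives
# these from its inputs and configuration.
B, H, D = 1, 4, 8          # batch, heads, per-head dim
chunk_size = [0, 10, 5]    # assumed layout: [.., chunk length, look-ahead frames]
look_back = 2              # number of past chunks kept when not -1
T = chunk_size[1] + chunk_size[2]   # frames fed per step: chunk + look-ahead

def update_cache(cache, k_h, v_h, chunk_size, look_back):
    """Sketch of the cache update in the patched `cache is not None` branch."""
    # Frames of the current step excluding the trailing look-ahead;
    # only these are kept for reuse by later chunks.
    k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
    v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
    # Keys/values actually attended over this step: cached history + current frames.
    k_full = torch.cat((cache["k"], k_h), dim=2)
    v_full = torch.cat((cache["v"], v_h), dim=2)
    # New cache: previous history plus the non-look-ahead part of the current chunk.
    cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
    cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
    if look_back != -1:
        # Bounded history: keep only the last look_back chunks.
        cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
        cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
    return k_full, v_full, cache

cache = {"k": torch.zeros(B, H, 0, D), "v": torch.zeros(B, H, 0, D)}
for step in range(3):
    k_h = torch.randn(B, H, T, D)
    v_h = torch.randn(B, H, T, D)
    k_full, v_full, cache = update_cache(cache, k_h, v_h, chunk_size, look_back)
    print(step, "attended frames:", k_full.shape[2], "cached frames:", cache["k"].shape[2])
```

As the diff reads, the old code cached a slice of the already-concatenated keys/values, so look-ahead frames of the current chunk ended up in the cache; the new code grows the cache only by the non-look-ahead frames of each chunk and, when `look_back == -1`, skips the truncation entirely, which appears to allow an unbounded left context.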