embedding.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Positional Encoding Module."""
  6. import math
  7. import torch
  8. def _pre_hook(
  9. state_dict,
  10. prefix,
  11. local_metadata,
  12. strict,
  13. missing_keys,
  14. unexpected_keys,
  15. error_msgs,
  16. ):
  17. """Perform pre-hook in load_state_dict for backward compatibility.
  18. Note:
  19. We saved self.pe until v.0.5.2 but we have omitted it later.
  20. Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
  21. """
  22. k = prefix + "pe"
  23. if k in state_dict:
  24. state_dict.pop(k)
  25. class PositionalEncoding(torch.nn.Module):
  26. """Positional encoding.
  27. Args:
  28. d_model (int): Embedding dimension.
  29. dropout_rate (float): Dropout rate.
  30. max_len (int): Maximum input length.
  31. reverse (bool): Whether to reverse the input position. Only for
  32. the class LegacyRelPositionalEncoding. We remove it in the current
  33. class RelPositionalEncoding.
  34. """
  35. def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
  36. """Construct an PositionalEncoding object."""
  37. super(PositionalEncoding, self).__init__()
  38. self.d_model = d_model
  39. self.reverse = reverse
  40. self.xscale = math.sqrt(self.d_model)
  41. self.dropout = torch.nn.Dropout(p=dropout_rate)
  42. self.pe = None
  43. self.extend_pe(torch.tensor(0.0).expand(1, max_len))
  44. self._register_load_state_dict_pre_hook(_pre_hook)
  45. def extend_pe(self, x):
  46. """Reset the positional encodings."""
  47. if self.pe is not None:
  48. if self.pe.size(1) >= x.size(1):
  49. if self.pe.dtype != x.dtype or self.pe.device != x.device:
  50. self.pe = self.pe.to(dtype=x.dtype, device=x.device)
  51. return
  52. pe = torch.zeros(x.size(1), self.d_model)
  53. if self.reverse:
  54. position = torch.arange(
  55. x.size(1) - 1, -1, -1.0, dtype=torch.float32
  56. ).unsqueeze(1)
  57. else:
  58. position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
  59. div_term = torch.exp(
  60. torch.arange(0, self.d_model, 2, dtype=torch.float32)
  61. * -(math.log(10000.0) / self.d_model)
  62. )
  63. pe[:, 0::2] = torch.sin(position * div_term)
  64. pe[:, 1::2] = torch.cos(position * div_term)
  65. pe = pe.unsqueeze(0)
  66. self.pe = pe.to(device=x.device, dtype=x.dtype)
  67. def forward(self, x: torch.Tensor):
  68. """Add positional encoding.
  69. Args:
  70. x (torch.Tensor): Input tensor (batch, time, `*`).
  71. Returns:
  72. torch.Tensor: Encoded tensor (batch, time, `*`).
  73. """
  74. self.extend_pe(x)
  75. x = x * self.xscale + self.pe[:, : x.size(1)]
  76. return self.dropout(x)
  77. class ScaledPositionalEncoding(PositionalEncoding):
  78. """Scaled positional encoding module.
  79. See Sec. 3.2 https://arxiv.org/abs/1809.08895
  80. Args:
  81. d_model (int): Embedding dimension.
  82. dropout_rate (float): Dropout rate.
  83. max_len (int): Maximum input length.
  84. """
  85. def __init__(self, d_model, dropout_rate, max_len=5000):
  86. """Initialize class."""
  87. super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
  88. self.alpha = torch.nn.Parameter(torch.tensor(1.0))
  89. def reset_parameters(self):
  90. """Reset parameters."""
  91. self.alpha.data = torch.tensor(1.0)
  92. def forward(self, x):
  93. """Add positional encoding.
  94. Args:
  95. x (torch.Tensor): Input tensor (batch, time, `*`).
  96. Returns:
  97. torch.Tensor: Encoded tensor (batch, time, `*`).
  98. """
  99. self.extend_pe(x)
  100. x = x + self.alpha * self.pe[:, : x.size(1)]
  101. return self.dropout(x)
  102. class LearnableFourierPosEnc(torch.nn.Module):
  103. """Learnable Fourier Features for Positional Encoding.
  104. See https://arxiv.org/pdf/2106.02795.pdf
  105. Args:
  106. d_model (int): Embedding dimension.
  107. dropout_rate (float): Dropout rate.
  108. max_len (int): Maximum input length.
  109. gamma (float): init parameter for the positional kernel variance
  110. see https://arxiv.org/pdf/2106.02795.pdf.
  111. apply_scaling (bool): Whether to scale the input before adding the pos encoding.
  112. hidden_dim (int): if not None, we modulate the pos encodings with
  113. an MLP whose hidden layer has hidden_dim neurons.
  114. """
  115. def __init__(
  116. self,
  117. d_model,
  118. dropout_rate=0.0,
  119. max_len=5000,
  120. gamma=1.0,
  121. apply_scaling=False,
  122. hidden_dim=None,
  123. ):
  124. """Initialize class."""
  125. super(LearnableFourierPosEnc, self).__init__()
  126. self.d_model = d_model
  127. if apply_scaling:
  128. self.xscale = math.sqrt(self.d_model)
  129. else:
  130. self.xscale = 1.0
  131. self.dropout = torch.nn.Dropout(dropout_rate)
  132. self.max_len = max_len
  133. self.gamma = gamma
  134. if self.gamma is None:
  135. self.gamma = self.d_model // 2
  136. assert (
  137. d_model % 2 == 0
  138. ), "d_model should be divisible by two in order to use this layer."
  139. self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
  140. self._reset() # init the weights
  141. self.hidden_dim = hidden_dim
  142. if self.hidden_dim is not None:
  143. self.mlp = torch.nn.Sequential(
  144. torch.nn.Linear(d_model, hidden_dim),
  145. torch.nn.GELU(),
  146. torch.nn.Linear(hidden_dim, d_model),
  147. )
  148. def _reset(self):
  149. self.w_r.data = torch.normal(
  150. 0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
  151. )
  152. def extend_pe(self, x):
  153. """Reset the positional encodings."""
  154. position_v = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1).to(x)
  155. cosine = torch.cos(torch.matmul(position_v, self.w_r))
  156. sine = torch.sin(torch.matmul(position_v, self.w_r))
  157. pos_enc = torch.cat((cosine, sine), -1)
  158. pos_enc /= math.sqrt(self.d_model)
  159. if self.hidden_dim is None:
  160. return pos_enc.unsqueeze(0)
  161. else:
  162. return self.mlp(pos_enc.unsqueeze(0))
  163. def forward(self, x: torch.Tensor):
  164. """Add positional encoding.
  165. Args:
  166. x (torch.Tensor): Input tensor (batch, time, `*`).
  167. Returns:
  168. torch.Tensor: Encoded tensor (batch, time, `*`).
  169. """
  170. pe = self.extend_pe(x)
  171. x = x * self.xscale + pe
  172. return self.dropout(x)
  173. class LegacyRelPositionalEncoding(PositionalEncoding):
  174. """Relative positional encoding module (old version).
  175. Details can be found in https://github.com/espnet/espnet/pull/2816.
  176. See : Appendix B in https://arxiv.org/abs/1901.02860
  177. Args:
  178. d_model (int): Embedding dimension.
  179. dropout_rate (float): Dropout rate.
  180. max_len (int): Maximum input length.
  181. """
  182. def __init__(self, d_model, dropout_rate, max_len=5000):
  183. """Initialize class."""
  184. super().__init__(
  185. d_model=d_model,
  186. dropout_rate=dropout_rate,
  187. max_len=max_len,
  188. reverse=True,
  189. )
  190. def forward(self, x):
  191. """Compute positional encoding.
  192. Args:
  193. x (torch.Tensor): Input tensor (batch, time, `*`).
  194. Returns:
  195. torch.Tensor: Encoded tensor (batch, time, `*`).
  196. torch.Tensor: Positional embedding tensor (1, time, `*`).
  197. """
  198. self.extend_pe(x)
  199. x = x * self.xscale
  200. pos_emb = self.pe[:, : x.size(1)]
  201. return self.dropout(x), self.dropout(pos_emb)
  202. class RelPositionalEncoding(torch.nn.Module):
  203. """Relative positional encoding module (new implementation).
  204. Details can be found in https://github.com/espnet/espnet/pull/2816.
  205. See : Appendix B in https://arxiv.org/abs/1901.02860
  206. Args:
  207. d_model (int): Embedding dimension.
  208. dropout_rate (float): Dropout rate.
  209. max_len (int): Maximum input length.
  210. """
  211. def __init__(self, d_model, dropout_rate, max_len=5000):
  212. """Construct an PositionalEncoding object."""
  213. super(RelPositionalEncoding, self).__init__()
  214. self.d_model = d_model
  215. self.xscale = math.sqrt(self.d_model)
  216. self.dropout = torch.nn.Dropout(p=dropout_rate)
  217. self.pe = None
  218. self.extend_pe(torch.tensor(0.0).expand(1, max_len))
  219. def extend_pe(self, x):
  220. """Reset the positional encodings."""
  221. if self.pe is not None:
  222. # self.pe contains both positive and negative parts
  223. # the length of self.pe is 2 * input_len - 1
  224. if self.pe.size(1) >= x.size(1) * 2 - 1:
  225. if self.pe.dtype != x.dtype or self.pe.device != x.device:
  226. self.pe = self.pe.to(dtype=x.dtype, device=x.device)
  227. return
  228. # Suppose `i` means to the position of query vecotr and `j` means the
  229. # position of key vector. We use position relative positions when keys
  230. # are to the left (i>j) and negative relative positions otherwise (i<j).
  231. pe_positive = torch.zeros(x.size(1), self.d_model)
  232. pe_negative = torch.zeros(x.size(1), self.d_model)
  233. position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
  234. div_term = torch.exp(
  235. torch.arange(0, self.d_model, 2, dtype=torch.float32)
  236. * -(math.log(10000.0) / self.d_model)
  237. )
  238. pe_positive[:, 0::2] = torch.sin(position * div_term)
  239. pe_positive[:, 1::2] = torch.cos(position * div_term)
  240. pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
  241. pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
  242. # Reserve the order of positive indices and concat both positive and
  243. # negative indices. This is used to support the shifting trick
  244. # as in https://arxiv.org/abs/1901.02860
  245. pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
  246. pe_negative = pe_negative[1:].unsqueeze(0)
  247. pe = torch.cat([pe_positive, pe_negative], dim=1)
  248. self.pe = pe.to(device=x.device, dtype=x.dtype)
  249. def forward(self, x: torch.Tensor):
  250. """Add positional encoding.
  251. Args:
  252. x (torch.Tensor): Input tensor (batch, time, `*`).
  253. Returns:
  254. torch.Tensor: Encoded tensor (batch, time, `*`).
  255. """
  256. self.extend_pe(x)
  257. x = x * self.xscale
  258. pos_emb = self.pe[
  259. :,
  260. self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
  261. ]
  262. return self.dropout(x), self.dropout(pos_emb)
  263. class StreamPositionalEncoding(torch.nn.Module):
  264. """Streaming Positional encoding.
  265. Args:
  266. d_model (int): Embedding dimension.
  267. dropout_rate (float): Dropout rate.
  268. max_len (int): Maximum input length.
  269. """
  270. def __init__(self, d_model, dropout_rate, max_len=5000):
  271. """Construct an PositionalEncoding object."""
  272. super(StreamPositionalEncoding, self).__init__()
  273. self.d_model = d_model
  274. self.xscale = math.sqrt(self.d_model)
  275. self.dropout = torch.nn.Dropout(p=dropout_rate)
  276. self.pe = None
  277. self.tmp = torch.tensor(0.0).expand(1, max_len)
  278. self.extend_pe(self.tmp.size(1), self.tmp.device, self.tmp.dtype)
  279. self._register_load_state_dict_pre_hook(_pre_hook)
  280. def extend_pe(self, length, device, dtype):
  281. """Reset the positional encodings."""
  282. if self.pe is not None:
  283. if self.pe.size(1) >= length:
  284. if self.pe.dtype != dtype or self.pe.device != device:
  285. self.pe = self.pe.to(dtype=dtype, device=device)
  286. return
  287. pe = torch.zeros(length, self.d_model)
  288. position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
  289. div_term = torch.exp(
  290. torch.arange(0, self.d_model, 2, dtype=torch.float32)
  291. * -(math.log(10000.0) / self.d_model)
  292. )
  293. pe[:, 0::2] = torch.sin(position * div_term)
  294. pe[:, 1::2] = torch.cos(position * div_term)
  295. pe = pe.unsqueeze(0)
  296. self.pe = pe.to(device=device, dtype=dtype)
  297. def forward(self, x: torch.Tensor, start_idx: int = 0):
  298. """Add positional encoding.
  299. Args:
  300. x (torch.Tensor): Input tensor (batch, time, `*`).
  301. Returns:
  302. torch.Tensor: Encoded tensor (batch, time, `*`).
  303. """
  304. self.extend_pe(x.size(1) + start_idx, x.device, x.dtype)
  305. x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
  306. return self.dropout(x)
  307. class SinusoidalPositionEncoder(torch.nn.Module):
  308. '''
  309. '''
  310. def __int__(self, d_model=80, dropout_rate=0.1):
  311. pass
  312. def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
  313. batch_size = positions.size(0)
  314. positions = positions.type(dtype)
  315. log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
  316. inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
  317. inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
  318. scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
  319. encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
  320. return encoding.type(dtype)
  321. def forward(self, x):
  322. batch_size, timesteps, input_dim = x.size()
  323. positions = torch.arange(1, timesteps+1)[None, :]
  324. position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
  325. return x + position_encoding