embed.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. """Positional Encoding Module."""
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. from funasr.modules.embedding import (
  6. LegacyRelPositionalEncoding, PositionalEncoding, RelPositionalEncoding,
  7. ScaledPositionalEncoding, StreamPositionalEncoding)
  8. from funasr.modules.subsampling import (
  9. Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6,
  10. Conv2dSubsampling8)
  11. from funasr.modules.subsampling_without_posenc import \
  12. Conv2dSubsamplingWOPosEnc
  13. from funasr.export.models.language_models.subsampling import (
  14. OnnxConv2dSubsampling, OnnxConv2dSubsampling2, OnnxConv2dSubsampling6,
  15. OnnxConv2dSubsampling8)
  def get_pos_emb(pos_emb, max_seq_len=512, use_cache=True):
      """Return the ONNX-export wrapper that corresponds to ``pos_emb``.

      Args:
          pos_emb: Positional-encoding module taken from the model being exported.
          max_seq_len (int): Passed through to the wrapper constructor.
          use_cache (bool): Passed through to the wrapper constructor.

      Returns:
          An ``Onnx*PositionalEncoding`` instance built from ``pos_emb``, or
          ``pos_emb`` itself when it is an empty ``nn.Sequential`` or a
          ``Conv2dSubsamplingWOPosEnc``.

      Raises:
          ValueError: If ``pos_emb`` is none of the handled types.
      """
      # Tried top to bottom; the first isinstance hit decides, so the order
      # below must stay exactly as listed.
      wrapper_for = (
          (LegacyRelPositionalEncoding, OnnxLegacyRelPositionalEncoding),
          (ScaledPositionalEncoding, OnnxScaledPositionalEncoding),
          (RelPositionalEncoding, OnnxRelPositionalEncoding),
          (PositionalEncoding, OnnxPositionalEncoding),
          (StreamPositionalEncoding, OnnxStreamPositionalEncoding),
      )
      for source_cls, onnx_cls in wrapper_for:
          if isinstance(pos_emb, source_cls):
              return onnx_cls(pos_emb, max_seq_len, use_cache)

      # Modules without a positional encoding are handed back untouched.
      empty_sequential = isinstance(pos_emb, nn.Sequential) and len(pos_emb) == 0
      if empty_sequential or isinstance(pos_emb, Conv2dSubsamplingWOPosEnc):
          return pos_emb

      raise ValueError("Embedding model is not supported.")
  class Embedding(nn.Module):
      """Export wrapper around a model's input embedding / subsampling front end.

      * An ``nn.Embedding`` is used unchanged.
      * A ``Conv2dSubsampling*`` module is replaced by the matching
        ``OnnxConv2dSubsampling*`` wrapper, and the last entry of that wrapper's
        ``out`` container is replaced by ``get_pos_emb(...)`` of the original's
        last ``out`` entry.
      * Any other module is indexed directly: its last entry is replaced in
        place by ``get_pos_emb(...)`` (the module passed in is mutated).

      Args:
          model: Embedding or subsampling module of the model being exported.
          max_seq_len (int): Forwarded to ``get_pos_emb``.
          use_cache (bool): Forwarded to ``get_pos_emb``.
      """

      def __init__(self, model, max_seq_len=512, use_cache=True):
          super().__init__()
          self.model = model
          if isinstance(model, nn.Embedding):
              return

          # The table is built only after the nn.Embedding shortcut, so that
          # path never touches the subsampling classes.
          subsampling_wrappers = (
              (Conv2dSubsampling, OnnxConv2dSubsampling),
              (Conv2dSubsampling2, OnnxConv2dSubsampling2),
              (Conv2dSubsampling6, OnnxConv2dSubsampling6),
              (Conv2dSubsampling8, OnnxConv2dSubsampling8),
          )
          for source_cls, onnx_cls in subsampling_wrappers:
              if isinstance(model, source_cls):
                  self.model = onnx_cls(model)
                  # use_cache used to be dropped here, which silently forced
                  # the cached table even when the caller asked for
                  # use_cache=False.
                  self.model.out[-1] = get_pos_emb(
                      model.out[-1], max_seq_len, use_cache
                  )
                  return

          # Fallback: the module itself is an indexable container whose last
          # entry is the positional encoding.
          self.model[-1] = get_pos_emb(model[-1], max_seq_len, use_cache)

      def forward(self, x, mask=None):
          """Run the wrapped module, passing ``mask`` only when one is given."""
          if mask is None:
              return self.model(x)
          return self.model(x, mask)
  57. def _pre_hook(
  58. state_dict,
  59. prefix,
  60. local_metadata,
  61. strict,
  62. missing_keys,
  63. unexpected_keys,
  64. error_msgs,
  65. ):
  66. """Perform pre-hook in load_state_dict for backward compatibility.
  67. Note:
  68. We saved self.pe until v.0.5.2 but we have omitted it later.
  69. Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
  70. """
  71. k = prefix + "pe"
  72. if k in state_dict:
  73. state_dict.pop(k)
  class OnnxPositionalEncoding(torch.nn.Module):
      """Export-friendly absolute sinusoidal positional encoding.

      Args:
          model: Source positional-encoding module. ``d_model`` and ``pe`` are
              read from it; ``extend_pe`` is called on it when the cache has to
              be grown.
          max_seq_len (int): Length the cached table is cut or extended to.
          reverse (bool): Whether to reverse the input position. Only for
              the class LegacyRelPositionalEncoding.
          use_cache (bool): If True, ``forward`` slices the stored table;
              otherwise the sinusoids are computed from the input length on
              every call.
      """

      def __init__(self, model, max_seq_len=512, reverse=False, use_cache=True):
          """Copy settings from ``model`` and prepare the table or ``div_term``."""
          super().__init__()
          self.d_model = model.d_model
          self.reverse = reverse
          self.max_seq_len = max_seq_len
          self.xscale = math.sqrt(self.d_model)
          self._register_load_state_dict_pre_hook(_pre_hook)
          self.pe = model.pe
          self.use_cache = use_cache
          self.model = model
          if use_cache:
              self.extend_pe()
              return
          # Inverse frequencies, only needed when the table is built on the fly.
          exponent = torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(
              math.log(10000.0) / self.d_model
          )
          self.div_term = torch.exp(exponent)

      def extend_pe(self):
          """Cut the cached table to ``max_seq_len`` or let the model regrow it."""
          cached_len = len(self.pe[0])
          if cached_len > self.max_seq_len:
              self.pe = self.pe[:, : self.max_seq_len]
              return
          # The wrapped module sizes its table from a (1, max_seq_len) tensor.
          probe = torch.tensor(0.0).expand(1, self.max_seq_len)
          self.model.extend_pe(probe)
          self.pe = self.model.pe

      def _add_pe(self, x):
          """Scale ``x`` by ``xscale`` and add freshly computed sinusoids."""
          steps = x.size(1)
          if self.reverse:
              position = torch.arange(steps - 1, -1, -1.0, dtype=torch.float32)
          else:
              position = torch.arange(0, steps, dtype=torch.float32)
          angles = position.unsqueeze(1) * self.div_term
          # The multiplication yields a new tensor, so the in-place adds below
          # do not modify the caller's input.
          x = x * self.xscale
          x[:, :, 0::2] += torch.sin(angles)
          x[:, :, 1::2] += torch.cos(angles)
          return x

      def forward(self, x: torch.Tensor):
          """Add positional encoding.

          Args:
              x (torch.Tensor): Input tensor (batch, time, `*`).

          Returns:
              torch.Tensor: Encoded tensor (batch, time, `*`).
          """
          if not self.use_cache:
              return self._add_pe(x)
          return x * self.xscale + self.pe[:, : x.size(1)]
  class OnnxScaledPositionalEncoding(OnnxPositionalEncoding):
      """Scaled positional encoding module for export.

      See Sec. 3.2 https://arxiv.org/abs/1809.08895

      The output is ``x + alpha * PE`` in both the cached and the non-cached
      mode.

      Args:
          model: Source positional-encoding module (see the base class). If it
              exposes an ``alpha`` attribute, that value initialises this
              module's ``alpha``.
          max_seq_len (int): Maximum input length.
          use_cache (bool): Slice a stored table (True) or compute the
              sinusoids on every call (False).
      """

      def __init__(self, model, max_seq_len=512, use_cache=True):
          """Initialize class."""
          super().__init__(model, max_seq_len, use_cache=use_cache)
          # Start from the wrapped module's scale when it has one; a fresh 1.0
          # would discard whatever value that module carries.
          source_alpha = getattr(model, "alpha", None)
          if source_alpha is None:
              initial = torch.tensor(1.0)
          else:
              initial = torch.as_tensor(source_alpha).detach().clone()
          self.alpha = torch.nn.Parameter(initial)

      def reset_parameters(self):
          """Reset parameters."""
          self.alpha.data = torch.tensor(1.0)

      def _add_pe(self, x):
          """Compute the sinusoids for ``x`` and return ``x + alpha * PE``."""
          steps = x.size(1)
          if self.reverse:
              position = torch.arange(steps - 1, -1, -1.0, dtype=torch.float32)
          else:
              position = torch.arange(0, steps, dtype=torch.float32)
          angles = position.unsqueeze(1) * self.div_term
          pe = torch.zeros(steps, x.size(2))
          pe[:, 0::2] = torch.sin(angles)
          pe[:, 1::2] = torch.cos(angles)
          # Same formula as the cached branch of forward(): alpha scales the
          # encoding, not the input.
          return x + self.alpha * pe

      def forward(self, x):
          """Add positional encoding.

          Args:
              x (torch.Tensor): Input tensor (batch, time, `*`).

          Returns:
              torch.Tensor: Encoded tensor (batch, time, `*`).
          """
          if self.use_cache:
              x = x + self.alpha * self.pe[:, : x.size(1)]
          else:
              x = self._add_pe(x)
          return x
  class OnnxLegacyRelPositionalEncoding(OnnxPositionalEncoding):
      """Relative positional encoding module (old version) for export.

      Details can be found in https://github.com/espnet/espnet/pull/2816.
      See : Appendix B in https://arxiv.org/abs/1901.02860

      Unlike the base class, ``forward`` does not add the encoding to the
      input; it returns the scaled input and the encoding separately.

      Args:
          model: Source positional-encoding module (see the base class).
          max_seq_len (int): Maximum input length.
          use_cache (bool): Slice a stored table (True) or compute the
              sinusoids on every call (False).
      """

      def __init__(self, model, max_seq_len=512, use_cache=True):
          """Initialize class with ``reverse=True``."""
          super().__init__(model, max_seq_len, reverse=True, use_cache=use_cache)

      def _get_pe(self, x):
          """Compute the positional table for ``x`` with a leading dim of 1."""
          steps = x.size(1)
          if self.reverse:
              position = torch.arange(steps - 1, -1, -1.0, dtype=torch.float32)
          else:
              position = torch.arange(0, steps, dtype=torch.float32)
          angles = position.unsqueeze(1) * self.div_term
          # One copy is enough: the table does not depend on the batch entry,
          # and a leading dim of 1 matches what the cached branch returns.
          pe = torch.zeros(1, steps, x.size(2))
          pe[:, :, 0::2] = torch.sin(angles)
          pe[:, :, 1::2] = torch.cos(angles)
          return pe

      def forward(self, x):
          """Compute positional encoding.

          Args:
              x (torch.Tensor): Input tensor (batch, time, `*`).

          Returns:
              torch.Tensor: Encoded tensor (batch, time, `*`).
              torch.Tensor: Positional embedding tensor (1, time, `*`).
          """
          x = x * self.xscale
          if self.use_cache:
              pos_emb = self.pe[:, : x.size(1)]
          else:
              pos_emb = self._get_pe(x)
          return x, pos_emb
  class OnnxRelPositionalEncoding(torch.nn.Module):
      """Relative positional encoding module (new implementation) for export.

      Details can be found in https://github.com/espnet/espnet/pull/2816.
      See : Appendix B in https://arxiv.org/abs/1901.02860

      Args:
          model: Source module; only ``d_model`` is read from it.
          max_seq_len (int): Number of positions the cached table is built for.
          use_cache (bool): Slice a stored table (True) or compute the
              sinusoids on every call (False).
      """

      def __init__(self, model, max_seq_len=512, use_cache=True):
          """Build the cached table, or just ``div_term`` when not caching."""
          super().__init__()
          self.d_model = model.d_model
          self.xscale = math.sqrt(self.d_model)
          self.pe = None
          self.use_cache = use_cache
          if use_cache:
              self.extend_pe(torch.tensor(0.0).expand(1, max_seq_len))
              return
          exponent = torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(
              math.log(10000.0) / self.d_model
          )
          self.div_term = torch.exp(exponent)

      def extend_pe(self, x):
          """Reset the positional encodings so they cover ``x.size(1)`` steps."""
          length = x.size(1)
          # self.pe holds both positive and negative offsets, so a table for
          # `length` steps has 2 * length - 1 rows.
          if self.pe is not None and self.pe.size(1) >= length * 2 - 1:
              mismatch = self.pe.dtype != x.dtype or self.pe.device != x.device
              if mismatch:
                  self.pe = self.pe.to(dtype=x.dtype, device=x.device)
              return

          # Let `i` be the query position and `j` the key position. Positive
          # relative positions are used when the key is to the left (i > j)
          # and negative ones otherwise (i < j).
          position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
          div_term = torch.exp(
              torch.arange(0, self.d_model, 2, dtype=torch.float32)
              * -(math.log(10000.0) / self.d_model)
          )
          pe_positive = torch.zeros(length, self.d_model)
          pe_negative = torch.zeros(length, self.d_model)
          pe_positive[:, 0::2] = torch.sin(position * div_term)
          pe_positive[:, 1::2] = torch.cos(position * div_term)
          pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
          pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

          # Reverse the positive half and append the negative half without its
          # zero row. This layout supports the shifting trick of
          # https://arxiv.org/abs/1901.02860
          left_half = torch.flip(pe_positive, [0]).unsqueeze(0)
          right_half = pe_negative[1:].unsqueeze(0)
          table = torch.cat([left_half, right_half], dim=1)
          self.pe = table.to(device=x.device, dtype=x.dtype)

      def _get_pe(self, x):
          """Compute the (1, 2 * time - 1, d_model) table for ``x`` directly."""
          length = x.size(1)
          theta = (
              torch.arange(0, length, dtype=torch.float32).unsqueeze(1) * self.div_term
          )
          pe_positive = torch.zeros(length, self.d_model)
          pe_negative = torch.zeros(length, self.d_model)
          pe_positive[:, 0::2] = torch.sin(theta)
          pe_positive[:, 1::2] = torch.cos(theta)
          pe_negative[:, 0::2] = -1 * torch.sin(theta)
          pe_negative[:, 1::2] = torch.cos(theta)
          # Same layout as extend_pe(): reversed positive half followed by the
          # negative half without its zero row.
          left_half = torch.flip(pe_positive, [0]).unsqueeze(0)
          right_half = pe_negative[1:].unsqueeze(0)
          return torch.cat([left_half, right_half], dim=1)

      def forward(self, x: torch.Tensor, use_cache=True):
          """Scale the input and return it with its relative position table.

          Args:
              x (torch.Tensor): Input tensor (batch, time, `*`).
              use_cache: Not read by this method; the mode is fixed by the
                  ``use_cache`` value given to the constructor.

          Returns:
              torch.Tensor: Scaled tensor (batch, time, `*`).
              torch.Tensor: Positional table (1, 2 * time - 1, `*`).
          """
          x = x * self.xscale
          if not self.use_cache:
              return x, self._get_pe(x)
          steps = x.size(1)
          centre = self.pe.size(1) // 2
          # Window of 2 * steps - 1 rows centred on the zero-offset row.
          pos_emb = self.pe[:, centre - steps + 1 : centre + steps]
          return x, pos_emb
  class OnnxStreamPositionalEncoding(torch.nn.Module):
      """Streaming positional encoding for export.

      Args:
          model: Source module; ``d_model``, ``xscale`` and ``pe`` are read
              from it.
          max_seq_len (int): Length the cached table is cut or extended to.
          use_cache (bool): Slice a stored table (True) or compute the
              sinusoids on every call (False).
      """

      def __init__(self, model, max_seq_len=5000, use_cache=True):
          """Copy settings from ``model`` and prepare the table."""
          # super() must name this class: passing StreamPositionalEncoding,
          # which is not a base of this wrapper, raises TypeError.
          super().__init__()
          self.use_cache = use_cache
          self.d_model = model.d_model
          self.xscale = model.xscale
          self.pe = model.pe
          self.max_seq_len = max_seq_len
          # Needed in both modes: extend_pe() builds the cached table from it.
          self.div_term = torch.exp(
              torch.arange(0, self.d_model, 2, dtype=torch.float32)
              * -(math.log(10000.0) / self.d_model)
          )
          if self.use_cache:
              self.extend_pe()
          self._register_load_state_dict_pre_hook(_pre_hook)

      def _table(self, start_idx, length):
          """Return the (1, length, d_model) sinusoid table starting at ``start_idx``."""
          position = torch.arange(
              start_idx, start_idx + length, dtype=torch.float32
          ).unsqueeze(1)
          angles = position * self.div_term
          table = torch.zeros(length, self.d_model)
          table[:, 0::2] = torch.sin(angles)
          table[:, 1::2] = torch.cos(angles)
          return table.unsqueeze(0)

      def extend_pe(self):
          """Cut the cached table to ``max_seq_len`` or rebuild it at that length.

          The table is rebuilt here with the same sinusoid formula that
          ``_add_pe`` uses, so no call into the wrapped module is needed.
          """
          cached_len = 0 if self.pe is None else self.pe.size(1)
          if self.max_seq_len < cached_len:
              self.pe = self.pe[:, : self.max_seq_len]
          elif self.max_seq_len > cached_len:
              table = self._table(0, self.max_seq_len)
              if self.pe is not None:
                  # Keep the dtype/device of the table being replaced.
                  table = table.to(dtype=self.pe.dtype, device=self.pe.device)
              self.pe = table

      def _add_pe(self, x, start_idx):
          """Scale ``x`` and add sinusoids for positions ``start_idx`` onward."""
          # Positions run from start_idx to start_idx + time - 1; stopping at
          # x.size(1) would give too few rows whenever start_idx > 0.
          return x * self.xscale + self._table(start_idx, x.size(1))

      def forward(self, x: torch.Tensor, start_idx: int = 0):
          """Add positional encoding.

          Args:
              x (torch.Tensor): Input tensor (batch, time, `*`).
              start_idx (int): Position of the first frame of ``x``.

          Returns:
              torch.Tensor: Encoded tensor (batch, time, `*`).
          """
          if self.use_cache:
              return x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
          return self._add_pe(x, start_idx)