rwkv_attention.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import math
  6. import torch
  7. from pathlib import Path
  8. from importlib.util import find_spec
  9. from typing import List, Optional, Tuple, Union
# Module-level handles for the compiled WKV CUDA extensions. They start as
# None and are assigned by load_encoder_wkv_kernel / load_decoder_wkv_kernel;
# the autograd functions below read them at call time.
wkv_kernel_encoder = None
wkv_kernel_decoder = None
  12. class WKVLinearAttentionEncoder(torch.autograd.Function):
  13. """WKVLinearAttention function definition."""
  14. @staticmethod
  15. def forward(
  16. ctx,
  17. time_decay: torch.Tensor,
  18. time_first: torch.Tensor,
  19. key: torch.Tensor,
  20. value: torch.tensor,
  21. ) -> torch.Tensor:
  22. """WKVLinearAttention function forward pass.
  23. Args:
  24. time_decay: Channel-wise time decay vector. (D_att)
  25. time_first: Channel-wise time first vector. (D_att)
  26. key: Key tensor. (B, U, D_att)
  27. value: Value tensor. (B, U, D_att)
  28. Returns:
  29. out: Weighted Key-Value tensor. (B, U, D_att)
  30. """
  31. batch, length, dim = key.size()
  32. assert length <= wkv_kernel_encoder.context_size, (
  33. f"Cannot process key of length {length} while context_size "
  34. f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
  35. )
  36. assert batch * dim % min(dim, 32) == 0, (
  37. f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
  38. f"{min(dim, 32)}"
  39. )
  40. ctx.input_dtype = key.dtype
  41. time_decay = -torch.exp(time_decay.float().contiguous())
  42. time_first = time_first.float().contiguous()
  43. key = key.float().contiguous()
  44. value = value.float().contiguous()
  45. out = torch.empty_like(key, memory_format=torch.contiguous_format)
  46. wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
  47. ctx.save_for_backward(time_decay, time_first, key, value, out)
  48. return out
  49. @staticmethod
  50. def backward(
  51. ctx, grad_output: torch.Tensor
  52. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  53. """WKVLinearAttention function backward pass.
  54. Args:
  55. grad_output: Output gradient. (B, U, D_att)
  56. Returns:
  57. grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
  58. grad_time_first: Gradient for channel-wise time first vector. (D_att)
  59. grad_key: Gradient for key tensor. (B, U, D_att)
  60. grad_value: Gradient for value tensor. (B, U, D_att)
  61. """
  62. time_decay, time_first, key, value, output = ctx.saved_tensors
  63. grad_dtype = ctx.input_dtype
  64. batch, _, dim = key.size()
  65. grad_time_decay = torch.empty(
  66. (batch, dim),
  67. memory_format=torch.contiguous_format,
  68. dtype=time_decay.dtype,
  69. device=time_decay.device,
  70. )
  71. grad_time_first = torch.empty(
  72. (batch, dim),
  73. memory_format=torch.contiguous_format,
  74. dtype=time_decay.dtype,
  75. device=time_decay.device,
  76. )
  77. grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
  78. grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
  79. wkv_kernel_encoder.backward(
  80. time_decay,
  81. time_first,
  82. key,
  83. value,
  84. output,
  85. grad_output.contiguous(),
  86. grad_time_decay,
  87. grad_time_first,
  88. grad_key,
  89. grad_value,
  90. )
  91. grad_time_decay = torch.sum(grad_time_decay, dim=0)
  92. grad_time_first = torch.sum(grad_time_first, dim=0)
  93. return (
  94. grad_time_decay,
  95. grad_time_first,
  96. grad_key,
  97. grad_value,
  98. )
  99. class WKVLinearAttentionDecoder(torch.autograd.Function):
  100. """WKVLinearAttention function definition."""
  101. @staticmethod
  102. def forward(
  103. ctx,
  104. time_decay: torch.Tensor,
  105. time_first: torch.Tensor,
  106. key: torch.Tensor,
  107. value: torch.tensor,
  108. ) -> torch.Tensor:
  109. """WKVLinearAttention function forward pass.
  110. Args:
  111. time_decay: Channel-wise time decay vector. (D_att)
  112. time_first: Channel-wise time first vector. (D_att)
  113. key: Key tensor. (B, U, D_att)
  114. value: Value tensor. (B, U, D_att)
  115. Returns:
  116. out: Weighted Key-Value tensor. (B, U, D_att)
  117. """
  118. batch, length, dim = key.size()
  119. assert length <= wkv_kernel_decoder.context_size, (
  120. f"Cannot process key of length {length} while context_size "
  121. f"is ({wkv_kernel.context_size}). Limit should be increased."
  122. )
  123. assert batch * dim % min(dim, 32) == 0, (
  124. f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
  125. f"{min(dim, 32)}"
  126. )
  127. ctx.input_dtype = key.dtype
  128. time_decay = -torch.exp(time_decay.float().contiguous())
  129. time_first = time_first.float().contiguous()
  130. key = key.float().contiguous()
  131. value = value.float().contiguous()
  132. out = torch.empty_like(key, memory_format=torch.contiguous_format)
  133. wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
  134. ctx.save_for_backward(time_decay, time_first, key, value, out)
  135. return out
  136. @staticmethod
  137. def backward(
  138. ctx, grad_output: torch.Tensor
  139. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  140. """WKVLinearAttention function backward pass.
  141. Args:
  142. grad_output: Output gradient. (B, U, D_att)
  143. Returns:
  144. grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
  145. grad_time_first: Gradient for channel-wise time first vector. (D_att)
  146. grad_key: Gradient for key tensor. (B, U, D_att)
  147. grad_value: Gradient for value tensor. (B, U, D_att)
  148. """
  149. time_decay, time_first, key, value, output = ctx.saved_tensors
  150. grad_dtype = ctx.input_dtype
  151. batch, _, dim = key.size()
  152. grad_time_decay = torch.empty(
  153. (batch, dim),
  154. memory_format=torch.contiguous_format,
  155. dtype=time_decay.dtype,
  156. device=time_decay.device,
  157. )
  158. grad_time_first = torch.empty(
  159. (batch, dim),
  160. memory_format=torch.contiguous_format,
  161. dtype=time_decay.dtype,
  162. device=time_decay.device,
  163. )
  164. grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
  165. grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
  166. wkv_kernel_decoder.backward(
  167. time_decay,
  168. time_first,
  169. key,
  170. value,
  171. output,
  172. grad_output.contiguous(),
  173. grad_time_decay,
  174. grad_time_first,
  175. grad_key,
  176. grad_value,
  177. )
  178. grad_time_decay = torch.sum(grad_time_decay, dim=0)
  179. grad_time_first = torch.sum(grad_time_first, dim=0)
  180. return (
  181. grad_time_decay,
  182. grad_time_first,
  183. grad_key,
  184. grad_value,
  185. )
  186. def load_encoder_wkv_kernel(context_size: int) -> None:
  187. """Load WKV CUDA kernel.
  188. Args:
  189. context_size: Context size.
  190. """
  191. from torch.utils.cpp_extension import load
  192. global wkv_kernel_encoder
  193. if wkv_kernel_encoder is not None and wkv_kernel_encoder.context_size == context_size:
  194. return
  195. if find_spec("ninja") is None:
  196. raise ImportError(
  197. "Ninja package was not found. WKV kernel module can't be loaded "
  198. "for training. Please, 'pip install ninja' in your environment."
  199. )
  200. if not torch.cuda.is_available():
  201. raise ImportError(
  202. "CUDA is currently a requirement for WKV kernel loading. "
  203. "Please set your devices properly and launch again."
  204. )
  205. kernel_folder = Path(__file__).resolve().parent / "cuda_encoder"
  206. kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]
  207. kernel_cflags = [
  208. "-res-usage",
  209. "--maxrregcount 60",
  210. "--use_fast_math",
  211. "-O3",
  212. "-Xptxas -O3",
  213. f"-DTmax={context_size}",
  214. ]
  215. wkv_kernel_encoder = load(
  216. name=f"encoder_wkv_{context_size}",
  217. sources=kernel_files,
  218. verbose=True,
  219. extra_cuda_cflags=kernel_cflags,
  220. )
  221. wkv_kernel_encoder.context_size = context_size
  222. def load_decoder_wkv_kernel(context_size: int) -> None:
  223. """Load WKV CUDA kernel.
  224. Args:
  225. context_size: Context size.
  226. """
  227. from torch.utils.cpp_extension import load
  228. global wkv_kernel_decoder
  229. if wkv_kernel_decoder is not None and wkv_kernel_decoder.context_size == context_size:
  230. return
  231. if find_spec("ninja") is None:
  232. raise ImportError(
  233. "Ninja package was not found. WKV kernel module can't be loaded "
  234. "for training. Please, 'pip install ninja' in your environment."
  235. )
  236. if not torch.cuda.is_available():
  237. raise ImportError(
  238. "CUDA is currently a requirement for WKV kernel loading. "
  239. "Please set your devices properly and launch again."
  240. )
  241. kernel_folder = Path(__file__).resolve().parent / "cuda_decoder"
  242. kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]
  243. kernel_cflags = [
  244. "-res-usage",
  245. "--maxrregcount 60",
  246. "--use_fast_math",
  247. "-O3",
  248. "-Xptxas -O3",
  249. f"-DTmax={context_size}",
  250. ]
  251. wkv_kernel_decoder = load(
  252. name=f"decoder_wkv_{context_size}",
  253. sources=kernel_files,
  254. verbose=True,
  255. extra_cuda_cflags=kernel_cflags,
  256. )
  257. wkv_kernel_decoder.context_size = context_size
  258. class SelfAttention(torch.nn.Module):
  259. """SelfAttention module definition.
  260. Args:
  261. size: Input/Output size.
  262. attention_size: Attention hidden size.
  263. context_size: Context size for WKV kernel.
  264. block_id: Block index.
  265. num_blocks: Number of blocks in the architecture.
  266. """
  267. def __init__(
  268. self,
  269. size: int,
  270. attention_size: int,
  271. block_id: int,
  272. dropout_rate: float,
  273. num_blocks: int,
  274. ) -> None:
  275. """Construct a SelfAttention object."""
  276. super().__init__()
  277. self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
  278. self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
  279. self.time_first = torch.nn.Parameter(torch.empty(attention_size))
  280. self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
  281. self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
  282. self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))
  283. self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
  284. self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
  285. self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)
  286. self.proj_output = torch.nn.Linear(attention_size, size, bias=True)
  287. self.block_id = block_id
  288. self.reset_parameters(size, attention_size, block_id, num_blocks)
  289. self.dropout = torch.nn.Dropout(p=dropout_rate)
  290. def reset_parameters(
  291. self, size: int, attention_size: int, block_id: int, num_blocks: int
  292. ) -> None:
  293. """Reset module parameters.
  294. Args:
  295. size: Block size.
  296. attention_size: Attention hidden size.
  297. block_id: Block index.
  298. num_blocks: Number of blocks in the architecture.
  299. """
  300. ratio_0_to_1 = block_id / (num_blocks - 1)
  301. ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)
  302. time_weight = torch.ones(1, 1, size)
  303. for i in range(size):
  304. time_weight[0, 0, i] = i / size
  305. decay_speed = [
  306. -5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
  307. for h in range(attention_size)
  308. ]
  309. decay_speed = torch.tensor(
  310. decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
  311. )
  312. zigzag = (
  313. torch.tensor(
  314. [(i + 1) % 3 - 1 for i in range(attention_size)],
  315. dtype=self.time_first.dtype,
  316. device=self.time_first.device,
  317. )
  318. * 0.5
  319. )
  320. with torch.no_grad():
  321. self.time_decay.data = decay_speed
  322. self.time_first.data = torch.ones_like(
  323. self.time_first * math.log(0.3) + zigzag
  324. )
  325. self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
  326. self.time_mix_value.data = (
  327. torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
  328. )
  329. self.time_mix_receptance.data = torch.pow(
  330. time_weight, 0.5 * ratio_1_to_almost0
  331. )
  332. @torch.no_grad()
  333. def wkv_linear_attention(
  334. self,
  335. time_decay: torch.Tensor,
  336. time_first: torch.Tensor,
  337. key: torch.Tensor,
  338. value: torch.Tensor,
  339. state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
  340. ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  341. """Compute WKV with state (i.e.: for inference).
  342. Args:
  343. time_decay: Channel-wise time decay vector. (D_att)
  344. time_first: Channel-wise time first vector. (D_att)
  345. key: Key tensor. (B, 1, D_att)
  346. value: Value tensor. (B, 1, D_att)
  347. state: Decoder hidden states. [3 x (B, D_att)]
  348. Returns:
  349. output: Weighted Key-Value. (B, 1, D_att)
  350. state: Decoder hidden states. [3 x (B, 1, D_att)]
  351. """
  352. num_state, den_state, max_state = state
  353. time_decay = -torch.exp(time_decay)
  354. max_for_output = torch.maximum(max_state, (time_first + key))
  355. e1 = torch.exp(max_state - max_for_output)
  356. e2 = torch.exp((time_first + key) - max_for_output)
  357. numerator = e1 * num_state + e2 * value
  358. denominator = e1 * den_state + e2
  359. max_for_state = torch.maximum(key, (max_state + time_decay))
  360. e1 = torch.exp((max_state + time_decay) - max_for_state)
  361. e2 = torch.exp(key - max_for_state)
  362. wkv = numerator / denominator
  363. state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]
  364. return wkv, state
  365. class DecoderSelfAttention(SelfAttention):
  366. """SelfAttention module definition.
  367. Args:
  368. size: Input/Output size.
  369. attention_size: Attention hidden size.
  370. context_size: Context size for WKV kernel.
  371. block_id: Block index.
  372. num_blocks: Number of blocks in the architecture.
  373. """
  374. def __init__(
  375. self,
  376. size: int,
  377. attention_size: int,
  378. context_size: int,
  379. block_id: int,
  380. dropout_rate: float,
  381. num_blocks: int,
  382. ) -> None:
  383. """Construct a SelfAttention object."""
  384. super().__init__(
  385. size,
  386. attention_size,
  387. block_id,
  388. dropout_rate,
  389. num_blocks
  390. )
  391. # load_decoder_wkv_kernel(context_size)
  392. def forward(
  393. self,
  394. x: torch.Tensor,
  395. state: Optional[List[torch.Tensor]] = None,
  396. ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
  397. """Compute time mixing.
  398. Args:
  399. x: SelfAttention input sequences. (B, U, size)
  400. state: Decoder hidden states. [5 x (B, 1, D_att, N)]
  401. Returns:
  402. x: SelfAttention output sequences. (B, U, size)
  403. """
  404. shifted_x = (
  405. self.time_shift(x) if state is None else state[1][..., self.block_id]
  406. )
  407. key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
  408. value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
  409. receptance = x * self.time_mix_receptance + shifted_x * (
  410. 1 - self.time_mix_receptance
  411. )
  412. key = self.proj_key(key)
  413. value = self.proj_value(value)
  414. receptance = torch.sigmoid(self.proj_receptance(receptance))
  415. if state is not None:
  416. state[1][..., self.block_id] = x
  417. wkv, att_state = self.wkv_linear_attention(
  418. self.time_decay,
  419. self.time_first,
  420. key,
  421. value,
  422. tuple(s[..., self.block_id] for s in state[2:]),
  423. )
  424. state[2][..., self.block_id] = att_state[0]
  425. state[3][..., self.block_id] = att_state[1]
  426. state[4][..., self.block_id] = att_state[2]
  427. else:
  428. wkv = WKVLinearAttentionDecoder.apply(self.time_decay, self.time_first, key, value)
  429. wkv = self.dropout(wkv)
  430. x = self.proj_output(receptance * wkv)
  431. return x, state
  432. class EncoderSelfAttention(SelfAttention):
  433. """SelfAttention module definition.
  434. Args:
  435. size: Input/Output size.
  436. attention_size: Attention hidden size.
  437. context_size: Context size for WKV kernel.
  438. block_id: Block index.
  439. num_blocks: Number of blocks in the architecture.
  440. """
  441. def __init__(
  442. self,
  443. size: int,
  444. attention_size: int,
  445. context_size: int,
  446. block_id: int,
  447. dropout_rate: float,
  448. num_blocks: int,
  449. ) -> None:
  450. """Construct a SelfAttention object."""
  451. super().__init__(
  452. size,
  453. attention_size,
  454. block_id,
  455. dropout_rate,
  456. num_blocks
  457. )
  458. # load_encoder_wkv_kernel(context_size)
  459. def forward(
  460. self,
  461. x: torch.Tensor,
  462. state: Optional[List[torch.Tensor]] = None,
  463. ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
  464. """Compute time mixing.
  465. Args:
  466. x: SelfAttention input sequences. (B, U, size)
  467. state: Decoder hidden states. [5 x (B, 1, D_att, N)]
  468. Returns:
  469. x: SelfAttention output sequences. (B, U, size)
  470. """
  471. shifted_x = (
  472. self.time_shift(x) if state is None else state[1][..., self.block_id]
  473. )
  474. key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
  475. value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
  476. receptance = x * self.time_mix_receptance + shifted_x * (
  477. 1 - self.time_mix_receptance
  478. )
  479. key = self.proj_key(key)
  480. value = self.proj_value(value)
  481. receptance = torch.sigmoid(self.proj_receptance(receptance))
  482. if state is not None:
  483. state[1][..., self.block_id] = x
  484. wkv, att_state = self.wkv_linear_attention(
  485. self.time_decay,
  486. self.time_first,
  487. key,
  488. value,
  489. tuple(s[..., self.block_id] for s in state[2:]),
  490. )
  491. state[2][..., self.block_id] = att_state[0]
  492. state[3][..., self.block_id] = att_state[1]
  493. state[4][..., self.block_id] = att_state[2]
  494. else:
  495. wkv = WKVLinearAttentionEncoder.apply(self.time_decay, self.time_first, key, value)
  496. wkv = self.dropout(wkv)
  497. x = self.proj_output(receptance * wkv)
  498. return x, state