| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- // Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu
- #include <stdio.h>
- #include <assert.h>
// Log-space "minus infinity" used to seed the running max exponent (pp) of the
// WKV recurrences. It is a finite value, presumably so that differences such as
// (pp - p) can never become inf - inf = NaN; exp() of such a difference simply
// underflows to 0. -1.0e38 is still representable in float (FLT_MAX ~ 3.4e38).
#define MIN_VALUE (-1.0e38)
// WKV forward pass, one thread per (batch, channel) pair.
//
// Launch: 1-D grid; thread idx handles b = idx / C, c = idx % C. Threads with
// idx >= B * C return immediately, so the grid may over-cover B * C.
//
// Layout (as indexed below): k, v and y are contiguous (B, T, C) buffers with C
// fastest, i.e. element [b][t][c] lives at b*T*C + t*C + c. w and u are
// per-channel vectors of length C.
// NOTE(review): w is presumably the log-decay already transformed on the
// Python side (w -> -exp(w)) — confirm against the caller.
//
// With a / b the decayed numerator / denominator sums, each step t computes
//   y_t = (a + e^(u + k_t) * v_t) / (b + e^(u + k_t))
//   a  <- e^w * a + e^(k_t) * v_t,   b <- e^w * b + e^(k_t)
// a and b are stored as (aa, bb) divided by exp(pp), where pp is a running max
// exponent, so every exp() argument below is <= 0 and cannot overflow.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * C) return;  // tail guard: surplus threads have no (b, c) to handle
    const int _b = idx / C;
    const int _c = idx % C;
    // 64-bit: B*T*C can exceed INT_MAX for large batches / long sequences.
    const long long _offset = static_cast<long long>(_b) * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // stride C between consecutive time steps
        const F kk = k[ii];
        const F vv = v[ii];

        // Output: combine the decayed history with the current token (bonus u).
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // State update: decay the history by w and add the current token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}
// WKV backward pass: gradients of the forward recurrence with respect to w, u,
// k and v, one thread per (batch, channel) pair.
//
// Launch: 1-D grid; thread idx handles b = idx / C, c = idx % C. Threads with
// idx >= B * C return immediately, so the grid may over-cover B * C.
//
// Layout (as indexed below): k, v, y, gy, gk and gv are contiguous (B, T, C)
// buffers with C fastest ([b][t][c] at b*T*C + t*C + c); w and u have length
// C. gw and gu are written per (b, c) at b*C + c, i.e. (B, C) buffers.
// NOTE(review): the caller presumably reduces gw / gu over B — confirm.
//
// Requires the compile-time macro Tmax. It is not defined in the visible
// source and is presumably passed as -DTmax=...; T must not exceed it, because
// q and r hold one value per time step in fixed-size per-thread arrays.
//
// Pass 1 (t = 0 .. T-1) replays the forward recurrence, accumulates gw and gu,
// and caches per step:
//   q[t] = gy[t] / (scaled denominator of y[t])
//   r[t] = (u + k[t]) - max(pp, u + k[t])   (log-scale of the current-token term)
// Pass 2 (t = T-1 .. 0) runs a reverse recurrence over those cached values to
// produce gk and gv.
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * C) return;  // tail guard: surplus threads have no (b, c) to handle
    const int _b = idx / C;
    const int _c = idx % C;
    // 64-bit: B*T*C can exceed INT_MAX for large batches / long sequences.
    const long long _offset = static_cast<long long>(_b) * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F q[Tmax], r[Tmax];  // per-step cache filled by pass 1; requires T <= Tmax

    // Pass 1: forward in time.
    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // stride C between consecutive time steps
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;
        r[i] = ww - p;

        // Advance the (aa, bb, pp) state exactly like the forward kernel does.
        // ga / gb accumulate the e1-decayed sums of the previous aa / bb,
        // which feed the gw term above.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    // Pass 2: backward in time; aa / bb / pp are reused for the reverse recurrence.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}
// Host launcher for kernel_forward<float>: one thread per (batch, channel) pair.
// w and u have length C; k, v and y are (B, T, C) buffers as indexed by the
// kernel. All pointers are passed straight to the kernel, so they must be
// device pointers.
// The launch goes to the default stream and launch errors are not checked
// here; the caller is responsible for that.
// Empty problems (B <= 0 or C <= 0) launch nothing.
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    if (B <= 0 || C <= 0) return;  // avoids a zero-sized block and modulo by zero
    const int n = B * C;           // number of (b, c) pairs = threads needed

    // Up to 32 threads per block (requires --maxrregcount 60 for optimal
    // performance). Shrink to the largest divisor of n so the grid covers every
    // (b, c) exactly once. Previously this was only an assert, which vanishes
    // under NDEBUG and then silently skipped the trailing channels.
    int threads = min(C, 32);
    while (n % threads != 0) --threads;  // terminates at 1 at the latest

    dim3 threadsPerBlock(threads);
    dim3 numBlocks(n / threads);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}
// Host launcher for kernel_backward<float>: one thread per (batch, channel) pair.
// w and u have length C; k, v, y, gy, gk and gv are (B, T, C) buffers; gw and
// gu receive one value per (b, c). All pointers are passed straight to the
// kernel, so they must be device pointers.
// The launch goes to the default stream and launch errors are not checked
// here; the caller is responsible for that.
// Empty problems (B <= 0 or C <= 0) launch nothing.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    if (B <= 0 || C <= 0) return;  // avoids a zero-sized block and modulo by zero
#ifdef Tmax
    // kernel_backward caches one value per time step in per-thread arrays of
    // length Tmax; a longer sequence would write past their end.
    assert(T <= Tmax);
#endif
    const int n = B * C;  // number of (b, c) pairs = threads needed

    // Up to 32 threads per block (requires --maxrregcount 60 for optimal
    // performance). Shrink to the largest divisor of n so the grid covers every
    // (b, c) exactly once, even when asserts are compiled out.
    int threads = min(C, 32);
    while (n % threads != 0) --threads;  // terminates at 1 at the latest

    dim3 threadsPerBlock(threads);
    dim3 numBlocks(n / threads);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
|