  // wkv_op.cpp (1.6 KB)
  /*
   * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp
   * Function signatures were modified based on
   * https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp
   */
  #include <cstdint>
  #include <limits>

  #include <c10/cuda/CUDAGuard.h>
  #include <torch/extension.h>
  // Host-side launchers for the WKV kernels. They are only declared here; the
  // definitions live in another translation unit (presumably the accompanying
  // .cu file -- not visible from this file). Every buffer is passed as a bare
  // float pointer with no stride, dtype or device information, and B, T, C are
  // plain ints. The parameter types must match the definitions exactly, or the
  // symbols will not link.
  void cuda_forward(int B, int T, int C,
                    float *w, float *u,
                    float *k, float *v,
                    float *y);

  void cuda_backward(int B, int T, int C,
                     float *w, float *u,
                     float *k, float *v,
                     float *y, float *gy,
                     float *gw, float *gu,
                     float *gk, float *gv);
  8. void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
  9. const int B = k.size(0);
  10. const int T = k.size(1);
  11. const int C = k.size(2);
  12. cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
  13. }
  14. void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
  15. const int B = k.size(0);
  16. const int T = k.size(1);
  17. const int C = k.size(2);
  18. cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
  19. }
  20. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  21. m.def("forward", &forward, "wkv forward");
  22. m.def("backward", &backward, "wkv backward");
  23. }
  24. TORCH_LIBRARY(wkv_encoder, m) {
  25. m.def("forward", forward);
  26. m.def("backward", backward);
  27. }