Tensor.h 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. #ifndef TENSOR_H
  2. #define TENSOR_H
  3. #include "alignedmem.h"
  4. using namespace std;
  /**
   * Tensor<T>: a dense tensor with exactly four axes whose elements live in
   * one buffer that the object allocates (alloc_buff) and releases in its
   * destructor.
   *
   * Constructors taking fewer than four extents pad the leading axes with 1,
   * e.g. Tensor(a, b) has shape (1, 1, a, b).
   * NOTE(review): element order inside buff is presumably row-major with
   * size[3] varying fastest -- confirm against the code that fills it.
   *
   * Copy construction and copy assignment from another Tensor make deep
   * copies: the object owns buff, so a compiler-generated member-wise copy
   * would leave two objects freeing the same buffer.
   */
  template <typename T> class Tensor {
  private:
      void alloc_buff();  // allocate buff for the shape in `size`; sets buff_size and mem_size
      void free_buff();   // release buff
      int mem_size;       // element count buff was allocated for
  public:
      T *buff;        // element storage, owned by this object
      int size[4];    // extent of each of the four axes
      int buff_size;  // number of elements currently in use
      Tensor(Tensor<T> *in);               // deep copy of *in
      Tensor(int a);                       // shape (1, 1, 1, a)
      Tensor(int a, int b);                // shape (1, 1, a, b)
      Tensor(int a, int b, int c);         // shape (1, a, b, c)
      Tensor(int a, int b, int c, int d);  // shape (a, b, c, d)

      // Deep copy: same shape, freshly allocated buffer, same elements.
      Tensor(const Tensor<T> &other)
      {
          memcpy(size, other.size, 4 * sizeof(int));
          alloc_buff();
          memcpy(buff, other.buff, buff_size * sizeof(T));
      }

      // Deep copy assignment; releases the current buffer before reallocating.
      Tensor<T> &operator=(const Tensor<T> &other)
      {
          if (this != &other) {
              free_buff();
              memcpy(size, other.size, 4 * sizeof(int));
              alloc_buff();
              memcpy(buff, other.buff, buff_size * sizeof(T));
          }
          return *this;
      }

      ~Tensor();
      void zeros();                              // zero-fill the elements in use
      void shape();                              // print the four extents with printf
      void disp();                               // print the elements in use to cout
      void dump(const char *mode);               // write raw element bytes to "tmp.bin" (fopen mode `mode`)
      void concat(Tensor<T> *din, int dim);      // concatenate *din along axis `dim`
      void resize(int a, int b, int c, int d);   // set the shape to (a, b, c, d)
      void add(float coe, Tensor<T> *in);        // elementwise: this = this + coe * in
      void add(Tensor<T> *in);                   // elementwise: this = this + in
      void add(Tensor<T> *in1, Tensor<T> *in2);  // elementwise: this = this + in1 + in2
      void reload(Tensor<T> *in);                // copy in's elements into buff (shape unchanged)
  };
  31. template <typename T> Tensor<T>::Tensor(int a) : size{1, 1, 1, a}
  32. {
  33. alloc_buff();
  34. }
  35. template <typename T> Tensor<T>::Tensor(int a, int b) : size{1, 1, a, b}
  36. {
  37. alloc_buff();
  38. }
  39. template <typename T> Tensor<T>::Tensor(int a, int b, int c) : size{1, a, b, c}
  40. {
  41. alloc_buff();
  42. }
  43. template <typename T>
  44. Tensor<T>::Tensor(int a, int b, int c, int d) : size{a, b, c, d}
  45. {
  46. alloc_buff();
  47. }
  48. template <typename T> Tensor<T>::Tensor(Tensor<T> *in)
  49. {
  50. memcpy(size, in->size, 4 * sizeof(int));
  51. alloc_buff();
  52. memcpy(buff, in->buff, in->buff_size * sizeof(T));
  53. }
  54. template <typename T> Tensor<T>::~Tensor()
  55. {
  56. free_buff();
  57. }
  58. template <typename T> void Tensor<T>::alloc_buff()
  59. {
  60. buff_size = size[0] * size[1] * size[2] * size[3];
  61. mem_size = buff_size;
  62. buff = (T *)aligned_malloc(32, buff_size * sizeof(T));
  63. }
  64. template <typename T> void Tensor<T>::free_buff()
  65. {
  66. aligned_free(buff);
  67. }
  68. template <typename T> void Tensor<T>::zeros()
  69. {
  70. memset(buff, 0, buff_size * sizeof(T));
  71. }
  72. template <typename T> void Tensor<T>::shape()
  73. {
  74. printf("(%d,%d,%d,%d)\n", size[0], size[1], size[2], size[3]);
  75. }
  76. // TODO:: fix it!!!!
  77. template <typename T> void Tensor<T>::concat(Tensor<T> *din, int dim)
  78. {
  79. memcpy(buff + buff_size, din->buff, din->buff_size * sizeof(T));
  80. buff_size += din->buff_size;
  81. size[dim] += din->size[dim];
  82. }
  83. // TODO:: fix it!!!!
  84. template <typename T> void Tensor<T>::resize(int a, int b, int c, int d)
  85. {
  86. size[0] = a;
  87. size[1] = b;
  88. size[2] = c;
  89. size[3] = d;
  90. buff_size = size[0] * size[1] * size[2] * size[3];
  91. }
  92. template <typename T> void Tensor<T>::add(float coe, Tensor<T> *in)
  93. {
  94. int i;
  95. for (i = 0; i < buff_size; i++) {
  96. buff[i] = buff[i] + coe * in->buff[i];
  97. }
  98. }
  99. template <typename T> void Tensor<T>::add(Tensor<T> *in)
  100. {
  101. int i;
  102. for (i = 0; i < buff_size; i++) {
  103. buff[i] = buff[i] + in->buff[i];
  104. }
  105. }
  106. template <typename T> void Tensor<T>::add(Tensor<T> *in1, Tensor<T> *in2)
  107. {
  108. int i;
  109. for (i = 0; i < buff_size; i++) {
  110. buff[i] = buff[i] + in1->buff[i] + in2->buff[i];
  111. }
  112. }
  113. template <typename T> void Tensor<T>::reload(Tensor<T> *in)
  114. {
  115. memcpy(buff, in->buff, in->buff_size * sizeof(T));
  116. }
  117. template <typename T> void Tensor<T>::disp()
  118. {
  119. int i;
  120. for (i = 0; i < buff_size; i++) {
  121. cout << buff[i] << " ";
  122. }
  123. cout << endl;
  124. }
  125. template <typename T> void Tensor<T>::dump(const char *mode)
  126. {
  127. FILE *fp;
  128. fp = fopen("tmp.bin", mode);
  129. fwrite(buff, 1, buff_size * sizeof(T), fp);
  130. fclose(fp);
  131. }
  132. #endif