CPUApplyUtils.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/CollapseDims.h>
  4. #include <ATen/Parallel.h>
  5. #include <ATen/TensorUtils.h>
  6. #include <c10/util/irange.h>
  7. #include <cstring>
  8. #include <limits>
  9. namespace at {
  10. /*
  11. * The basic strategy for apply is as follows:
  12. *
  13. * 1. Starting with the outermost index, loop until we reach a dimension where
  14. * the data is no longer contiguous, i.e. the stride at that dimension is not
  15. * equal to the size of the tensor defined by the outer dimensions. Let's call
  16. * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then
  17. * A is equal to the entire Tensor. Let's call the inner tensor B.
  18. *
  19. * 2. We loop through the indices in B, starting at its outermost dimension. For
  20. * example, if B is a 2x2 matrix, then we do:
  21. *
  22. * B[0][0]
  23. * B[0][1]
  24. * B[1][0]
  25. * B[1][1]
  26. *
  27. * We set the offset into the underlying storage as (storageOffset + stride_B *
  28. * index_B), i.e. basically we compute the offset into the storage as we would
  29. * normally for a Tensor. But because we are guaranteed the subsequent data is
  30. * contiguous in memory, we can simply loop for sizeof(A) iterations and perform
  31. * the operation, without having to follow the order described by the strides of
  32. * A.
  33. *
  34. * 3. As an optimization, we merge dimensions of A that are contiguous in
  35. * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
  36. * then the first two dimensions can be merged for the purposes of APPLY,
  37. * reducing the number of nested loops.
  38. */
  39. inline Tensor sort_strides(Tensor& tensor_) {
  40. IntArrayRef strides = tensor_.strides();
  41. std::vector<int64_t> indices;
  42. indices.reserve(tensor_.ndimension());
  43. for (const auto i : c10::irange(tensor_.ndimension())) {
  44. indices.push_back(i);
  45. }
  46. std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
  47. return strides[i1] > strides[i2];
  48. });
  49. Tensor tensor = tensor_.permute(indices);
  50. return tensor;
  51. }
  52. template <typename T, int N>
  53. struct strided_tensor_iter_fixed {
  54. public:
  55. T* data_ = NULL;
  56. int64_t dim_ = 0;
  57. // NOLINTNEXTLINE(*array*)
  58. int64_t counter_[N] = {0};
  59. // NOLINTNEXTLINE(*array*)
  60. int64_t sizes_[N] = {0};
  61. // NOLINTNEXTLINE(*array*)
  62. int64_t strides_[N] = {0};
  63. strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  64. strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed const& x) =
  65. delete;
  66. strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) noexcept = default;
  67. strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed&& x) noexcept =
  68. default;
  69. ~strided_tensor_iter_fixed() noexcept = default;
  70. strided_tensor_iter_fixed(
  71. Tensor& tensor,
  72. [[maybe_unused]] bool sort_strides = false)
  73. : data_(tensor.data_ptr<T>()) {
  74. std::memset(counter_, 0, sizeof(int64_t) * N);
  75. if (tensor.dim() > 0) {
  76. std::memcpy(
  77. sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
  78. std::memcpy(
  79. strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
  80. }
  81. dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  82. }
  83. };
  84. template <typename T>
  85. struct strided_tensor_iter {
  86. private:
  87. public:
  88. T* data_ = NULL;
  89. int64_t dim_;
  90. std::vector<int64_t> counter_;
  91. std::vector<int64_t> sizes_;
  92. std::vector<int64_t> strides_;
  93. strided_tensor_iter(strided_tensor_iter const&) = delete;
  94. strided_tensor_iter& operator=(strided_tensor_iter const& x) = delete;
  95. strided_tensor_iter(strided_tensor_iter&&) noexcept = default;
  96. strided_tensor_iter& operator=(strided_tensor_iter&&) noexcept = default;
  97. ~strided_tensor_iter() noexcept = default;
  98. strided_tensor_iter(Tensor& tensor)
  99. : data_(tensor.data_ptr<T>()),
  100. dim_(tensor.ndimension()),
  101. counter_(dim_, 0),
  102. sizes_(tensor.sizes().vec()),
  103. strides_(tensor.strides().vec()) {
  104. dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  105. }
  106. };
  107. inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  108. if (tensors.empty())
  109. return true;
  110. int64_t all_numel = tensors[0].numel();
  111. for (const auto i : c10::irange(1, tensors.size())) {
  112. if (tensors[i].numel() != all_numel)
  113. return false;
  114. }
  115. return true;
  116. }
  117. inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  118. std::ostringstream oss;
  119. oss << "inconsistent tensor size, expected ";
  120. for (size_t i = 0; i < tensors.size() - 1; i++) {
  121. oss << tensors[i].sizes() << ", ";
  122. }
  123. oss << "and " << tensors[tensors.size() - 1].sizes()
  124. << " to have the same number of elements, but got ";
  125. for (size_t i = 0; i < tensors.size() - 1; i++) {
  126. oss << tensors[i].numel() << ", ";
  127. }
  128. oss << "and " << tensors[tensors.size() - 1].numel()
  129. << " elements respectively";
  130. return oss.str();
  131. }
  132. inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  133. checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  134. checkLayout("CPU_tensor_apply", tensors, kStrided);
  135. TORCH_CHECK(_all_equal_numel(tensors), _all_equal_numel_error(tensors));
  136. // An empty tensor has no elements
  137. for (auto& t : tensors)
  138. if (t.numel() == 0)
  139. return false;
  140. return true;
  141. }
  142. inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  143. int64_t dim = 0;
  144. for (auto& t : tensors)
  145. dim = std::max(dim, t.ndimension());
  146. return dim;
  147. }
  /// Recursion terminator: no iterators left to advance.
  inline void iterate(int64_t /*size*/) {}

  /// Advances every given iterator by `size` elements along its last
  /// dimension (index dim_ - 1): the counter of that dimension grows by `size`
  /// and data_ moves by `size` times that dimension's stride.
  template <typename Arg, typename... Args>
  inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
    const int64_t last = iter.dim_ - 1;
    iter.counter_[last] += size;
    iter.data_ += size * iter.strides_[last];
    // Apply the same step to the remaining iterators.
    iterate(size, iter_tail...);
  }
  /// Recursion terminator: with no iterators left, iteration may continue.
  inline bool iterate_continue() {
    return true;
  }

  /// Returns true only if every given iterator still has elements left in its
  /// last dimension, i.e. counter_[dim_ - 1] < sizes_[dim_ - 1] for all of
  /// them. Stops checking at the first exhausted iterator.
  template <typename Arg, typename... Args>
  inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
    const int64_t last = iter.dim_ - 1;
    if (!(iter.counter_[last] < iter.sizes_[last])) {
      return false;
    }
    return iterate_continue(iter_tail...);
  }
  /// Recursion terminator: without iterators there is no limit, so the largest
  /// int64_t value is returned.
  inline int64_t max_iterate_size() {
    return std::numeric_limits<int64_t>::max();
  }

  /// Returns the smallest number of elements remaining in the last dimension
  /// (sizes_[dim_ - 1] - counter_[dim_ - 1]) across all given iterators.
  template <typename Arg, typename... Args>
  inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
    const int64_t last = iter.dim_ - 1;
    const int64_t left_here = iter.sizes_[last] - iter.counter_[last];
    const int64_t left_elsewhere = max_iterate_size(iter_tail...);
    return std::min(left_here, left_elsewhere);
  }
  /// Recursion terminator: no iterators left to adjust.
  inline void iterate_overflow() {}

  /// Carries exhausted dimensions over into the next outer dimension, for
  /// every given iterator.
  ///
  /// Nothing happens for an iterator unless the counter of its last dimension
  /// has reached that dimension's size. In that case dimensions are visited
  /// from dim_ - 1 down to 1, and each one whose counter equals its size is
  /// reset to 0 while the counter of the next outer dimension is incremented.
  /// data_ is rewound by sizes_[d] * strides_[d] and advanced by
  /// strides_[d - 1]. Dimension 0 is never reset.
  template <typename Arg, typename... Args>
  inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
    const int64_t last = iter.dim_ - 1;
    if (iter.counter_[last] == iter.sizes_[last]) {
      for (int64_t d = last; d >= 1; --d) {
        if (iter.counter_[d] != iter.sizes_[d]) {
          continue; // this dimension is not exhausted; keep scanning outward
        }
        iter.counter_[d] = 0;
        ++iter.counter_[d - 1];
        // Undo the full sweep of dimension d, then step the outer dimension.
        iter.data_ -= iter.sizes_[d] * iter.strides_[d];
        iter.data_ += iter.strides_[d - 1];
      }
    }
    iterate_overflow(iter_tail...);
  }
  /// Recursion terminator: no iterators left to move.
  inline void forward(int64_t /*offset*/) {}

  /// Moves every given iterator ahead by `offset` elements.
  ///
  /// `offset` is decomposed from the last dimension to the first: at each
  /// dimension the remainder modulo that dimension's size is added to its
  /// counter (and, multiplied by the stride, to data_), and the quotient is
  /// carried on to the next outer dimension.
  template <typename Arg, typename... Args>
  inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
    int64_t remaining = offset;
    int64_t d = iter.dim_ - 1;
    while (d >= 0) {
      const int64_t extent = iter.sizes_[d];
      const int64_t step = remaining % extent;
      remaining /= extent;
      iter.data_ += step * iter.strides_[d];
      iter.counter_[d] += step;
      --d;
    }
    // Every iterator is moved by the same linear offset.
    forward(offset, iter_tail...);
  }
  /// Recursion terminator: with no iterators the maximum is 0.
  inline int64_t max_dim() {
    return 0;
  }

  /// Returns the largest dim_ among the given iterators.
  template <typename Arg, typename... Args>
  inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
    const int64_t rest = max_dim(iter_tail...);
    if (iter.dim_ < rest) {
      return rest;
    }
    return iter.dim_;
  }
  /// Zero-argument overload; does nothing.
  inline void apply_op() {}

  /// Invokes `op` once per element on the current elements of all iterators,
  /// advancing the iterators in lockstep.
  ///
  /// @param numel   number of elements to visit.
  /// @param offset  when greater than 0, every iterator is first moved ahead
  ///                by this many elements via forward().
  /// @param op      callable invoked as op(*iters.data_...), i.e. with one
  ///                dereferenced element per iterator.
  /// @param iters   iterator objects, taken by value.
  template <typename Op, typename... Args>
  inline void apply_op(
      int64_t numel,
      int64_t offset,
      const Op& op,
      Args... iters) {
    // For 0-dim tensors: a single element and no dimension to walk.
    if (numel == 1 && max_dim(iters...) == 0) {
      op(*iters.data_...);
      return;
    }

    if (offset > 0) {
      forward(offset, iters...);
    }

    // Splitting this into chunks helps the compiler create faster assembly:
    // the inner loop runs along the last dimension until some iterator
    // exhausts it, then iterate_overflow carries into the outer dimensions.
    int64_t visited = 0;
    while (visited < numel) {
      while (iterate_continue(iters...) && visited < numel) {
        op(*iters.data_...);
        iterate(1, iters...);
        ++visited;
      }
      iterate_overflow(iters...);
    }
  }
  /*
    Apply a pointwise operator to a sequence of tensors.

    The calling convention for op is a function/functor that takes one
    reference per given tensor, each to an element of that tensor's scalar
    type (apply_op invokes it as op(*iters.data_...)). For example, to compute
    a = b * c, op would be of the form:
    [](scalar& a_val, const scalar& b_val, const scalar& c_val) {
      a_val = b_val * c_val; };
  */
  238. template <typename scalar1, typename scalar2, typename Op>
  239. inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  240. if (!_apply_preamble({tensor1, tensor2}))
  241. return;
  242. if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
  243. apply_op(
  244. tensor1.numel(),
  245. 0,
  246. op,
  247. strided_tensor_iter_fixed<scalar1, 8>(tensor1),
  248. strided_tensor_iter_fixed<scalar2, 8>(tensor2));
  249. } else {
  250. apply_op(
  251. tensor1.numel(),
  252. 0,
  253. op,
  254. strided_tensor_iter<scalar1>(tensor1),
  255. strided_tensor_iter<scalar2>(tensor2));
  256. }
  257. }
  258. template <typename scalar1, typename scalar2, typename scalar3, typename Op>
  259. inline void CPU_tensor_apply3(
  260. Tensor tensor1,
  261. Tensor tensor2,
  262. Tensor tensor3,
  263. const Op op) {
  264. if (!_apply_preamble({tensor1, tensor2, tensor3}))
  265. return;
  266. if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
  267. apply_op(
  268. tensor1.numel(),
  269. 0,
  270. op,
  271. strided_tensor_iter_fixed<scalar1, 8>(tensor1),
  272. strided_tensor_iter_fixed<scalar2, 8>(tensor2),
  273. strided_tensor_iter_fixed<scalar3, 8>(tensor3));
  274. } else {
  275. apply_op(
  276. tensor1.numel(),
  277. 0,
  278. op,
  279. strided_tensor_iter<scalar1>(tensor1),
  280. strided_tensor_iter<scalar2>(tensor2),
  281. strided_tensor_iter<scalar3>(tensor3));
  282. }
  283. }
  284. template <
  285. typename scalar1,
  286. typename scalar2,
  287. typename scalar3,
  288. typename scalar4,
  289. typename Op>
  290. inline void CPU_tensor_apply4(
  291. Tensor tensor1,
  292. Tensor tensor2,
  293. Tensor tensor3,
  294. Tensor tensor4,
  295. const Op op) {
  296. if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
  297. return;
  298. if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
  299. apply_op(
  300. tensor1.numel(),
  301. 0,
  302. op,
  303. strided_tensor_iter_fixed<scalar1, 8>(tensor1),
  304. strided_tensor_iter_fixed<scalar2, 8>(tensor2),
  305. strided_tensor_iter_fixed<scalar3, 8>(tensor3),
  306. strided_tensor_iter_fixed<scalar4, 8>(tensor4));
  307. } else {
  308. apply_op(
  309. tensor1.numel(),
  310. 0,
  311. op,
  312. strided_tensor_iter<scalar1>(tensor1),
  313. strided_tensor_iter<scalar2>(tensor2),
  314. strided_tensor_iter<scalar3>(tensor3),
  315. strided_tensor_iter<scalar4>(tensor4));
  316. }
  317. }
  318. } // namespace at
  319. #else
  320. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  321. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)