SparseCsrTensorUtils.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/SparseCsrTensorImpl.h>
  4. #include <ATen/SparseTensorImpl.h>
  5. #include <ATen/core/Tensor.h>
  6. #ifndef AT_PER_OPERATOR_HEADERS
  7. #include <ATen/Functions.h>
  8. #include <ATen/NativeFunctions.h>
  9. #include <ATen/Operators.h>
  10. #else
  11. #include <ATen/ops/_sparse_compressed_tensor_unsafe.h>
  12. #include <ATen/ops/resize_as_sparse_native.h>
  13. #endif
  // Dispatches over every sparse compressed layout (CSR, CSC, BSR, BSC).
  //
  //   LAYOUT - expression yielding the layout to dispatch on.
  //   NAME   - caller name; passed as the first fragment of the error message.
  //   ...    - a nullary callable. It is invoked for any of the four layouts
  //            and its result is the value of the whole macro expression.
  //
  // Any other layout fails a TORCH_CHECK. The callable is wrapped in
  // parentheses before it is invoked, as the other dispatch macros in this
  // file do, so that an argument which is not a primary expression (for
  // example a conditional expression choosing between two callables) is
  // called as a whole instead of having `()` bind to its last operand.
  #define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \
    [&] { \
      const auto& the_layout = LAYOUT; \
      switch (the_layout) { \
        case kSparseCsr: \
        case kSparseCsc: \
        case kSparseBsr: \
        case kSparseBsc: \
          return (__VA_ARGS__)(); \
        default: \
          TORCH_CHECK( \
              false, \
              NAME, \
              " expected sparse compressed tensor layout but got ", \
              the_layout); \
      } \
    }()
  // Two-way dispatch on which dimension of a sparse compressed layout is
  // compressed. ROW_DIM_ACTION is invoked for CSR/BSR and COLUMN_DIM_ACTION
  // for CSC/BSC; any other layout fails a TORCH_CHECK whose message begins
  // with NAME. Both actions are nullary callables and the macro evaluates to
  // the result of the one that was selected.
  #define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \
      LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \
    [&] { \
      const auto& the_layout = LAYOUT; \
      switch (the_layout) { \
        case kSparseBsr: \
        case kSparseCsr: \
          return (ROW_DIM_ACTION)(); \
        case kSparseBsc: \
        case kSparseCsc: \
          return (COLUMN_DIM_ACTION)(); \
        default: \
          TORCH_CHECK(false, NAME, \
              " expected sparse compressed tensor layout but got ", \
              the_layout); \
      } \
    }()
  // Two-way dispatch on whether a sparse compressed layout is blocked.
  // NO_BLOCK_ACTION is invoked for CSR/CSC and BLOCK_ACTION for BSR/BSC; any
  // other layout fails a TORCH_CHECK whose message begins with NAME. Both
  // actions are nullary callables and the macro evaluates to the result of
  // the one that was selected.
  #define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \
      LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \
    [&] { \
      const auto& the_layout = LAYOUT; \
      switch (the_layout) { \
        case kSparseCsc: \
        case kSparseCsr: \
          return (NO_BLOCK_ACTION)(); \
        case kSparseBsc: \
        case kSparseBsr: \
          return (BLOCK_ACTION)(); \
        default: \
          TORCH_CHECK(false, NAME, \
              " expected sparse compressed tensor layout but got ", \
              the_layout); \
      } \
    }()
  // Dispatch restricted to the row-compressed layouts. ROW_DIM_ACTION (a
  // nullary callable) is invoked for CSR and BSR and its result is the value
  // of the macro; every other layout fails a TORCH_CHECK whose message begins
  // with NAME.
  #define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \
      LAYOUT, NAME, ROW_DIM_ACTION) \
    [&] { \
      const auto& the_layout = LAYOUT; \
      switch (the_layout) { \
        case kSparseBsr: \
        case kSparseCsr: \
          return (ROW_DIM_ACTION)(); \
        default: \
          TORCH_CHECK(false, NAME, \
              " expected sparse row compressed tensor layout but got ", \
              the_layout); \
      } \
    }()
  // Dispatch restricted to the column-compressed layouts. COL_DIM_ACTION (a
  // nullary callable) is invoked for CSC and BSC and its result is the value
  // of the macro; every other layout fails a TORCH_CHECK whose message begins
  // with NAME.
  #define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \
      LAYOUT, NAME, COL_DIM_ACTION) \
    [&] { \
      const auto& the_layout = LAYOUT; \
      switch (the_layout) { \
        case kSparseBsc: \
        case kSparseCsc: \
          return (COL_DIM_ACTION)(); \
        default: \
          TORCH_CHECK(false, NAME, \
              " expected sparse column compressed tensor layout but got ", \
              the_layout); \
      } \
    }()
  // Dispatch restricted to the non-block compressed layouts. ACTION (a
  // nullary callable) is invoked for CSR and CSC and its result is the value
  // of the macro; every other layout fails a TORCH_CHECK whose message begins
  // with NAME.
  #define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
    [&] { \
      const auto& the_layout = LAYOUT; \
      switch (the_layout) { \
        case kSparseCsc: \
        case kSparseCsr: \
          return (ACTION)(); \
        default: \
          TORCH_CHECK(false, NAME, \
              " expected sparse compressed (non-block) tensor layout but got ", \
              the_layout); \
      } \
    }()
  // Dispatch restricted to the block compressed layouts. ACTION (a nullary
  // callable) is invoked for BSR and BSC and its result is the value of the
  // macro; every other layout fails a TORCH_CHECK whose message begins with
  // NAME.
  #define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \
    [&] { \
      const auto& the_layout = LAYOUT; \
      switch (the_layout) { \
        case kSparseBsc: \
        case kSparseBsr: \
          return (ACTION)(); \
        default: \
          TORCH_CHECK(false, NAME, \
              " expected sparse compressed block tensor layout but got ", \
              the_layout); \
      } \
    }()
  // Dispatches on the dtype TYPE through AT_DISPATCH_SWITCH, using the case
  // list produced by AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4 with the four
  // extra types kComplexHalf, kHalf, kBool and kBFloat16. NAME and the
  // trailing arguments are forwarded unchanged.
  #define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH( \
        TYPE, NAME, \
        AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
            kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__))
  137. namespace at::sparse_csr {
  138. // Implements RAII object to manage checking sparse tensor invariants:
  139. class CheckSparseTensorInvariants {
  140. std::optional<bool> old_state;
  141. public:
  142. CheckSparseTensorInvariants(bool state)
  143. : old_state(at::globalContext().checkSparseTensorInvariants()) {
  144. at::globalContext().setCheckSparseTensorInvariants(state);
  145. }
  146. CheckSparseTensorInvariants(CheckSparseTensorInvariants&& other) = delete;
  147. CheckSparseTensorInvariants(const CheckSparseTensorInvariants&) = delete;
  148. CheckSparseTensorInvariants& operator=(const CheckSparseTensorInvariants&) =
  149. delete;
  150. CheckSparseTensorInvariants& operator=(CheckSparseTensorInvariants&&) =
  151. delete;
  152. ~CheckSparseTensorInvariants() {
  153. at::globalContext().setCheckSparseTensorInvariants(old_state);
  154. }
  155. };
  156. using SparseCsrTensor = Tensor;
  157. inline bool is_sparse_compressed(const Layout& layout) {
  158. switch (layout) {
  159. case kSparseCsr:
  160. case kSparseCsc:
  161. case kSparseBsr:
  162. case kSparseBsc:
  163. return true;
  164. default:;
  165. }
  166. return false;
  167. }
  168. inline bool is_sparse_compressed(const Tensor& self) {
  169. return is_sparse_compressed(self.layout());
  170. }
  171. inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) {
  172. AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(
  173. self.layout(), "get_sparse_csr_impl", [&] {});
  174. return static_cast<SparseCsrTensorImpl*>(self.unsafeGetTensorImpl());
  175. }
  176. inline std::string layoutToString(
  177. Layout layout,
  178. bool upper = false,
  179. bool lower = false) {
  180. switch (layout) {
  181. case kSparseCsr:
  182. return (upper ? "CSR" : (lower ? "csr" : "Csr"));
  183. case kSparseCsc:
  184. return (upper ? "CSC" : (lower ? "csc" : "Csc"));
  185. case kSparseBsr:
  186. return (upper ? "BSR" : (lower ? "bsr" : "Bsr"));
  187. case kSparseBsc:
  188. return (upper ? "BSC" : (lower ? "bsc" : "Bsc"));
  189. default:
  190. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  191. return "";
  192. }
  193. }
  194. inline bool isCompressedRow(Layout layout) {
  195. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  196. layout, "isCompressedRow", [&] { return true; }, [&] { return false; });
  197. }
  198. inline bool isCompressedColumn(Layout layout) {
  199. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  200. layout,
  201. "isCompressedColumn",
  202. [&] { return false; },
  203. [&] { return true; });
  204. }
  205. inline std::string compressedIndicesName(Layout layout) {
  206. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  207. layout,
  208. "compressedIndicesName",
  209. [&] { return "crow_indices"; },
  210. [&] { return "ccol_indices"; });
  211. }
  212. inline std::string plainIndicesName(Layout layout) {
  213. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  214. layout,
  215. "plainIndicesName",
  216. [&] { return "col_indices"; },
  217. [&] { return "row_indices"; });
  218. }
  219. inline std::string compressedDimName(Layout layout) {
  220. switch (layout) {
  221. case kSparseCsr:
  222. return "row";
  223. case kSparseCsc:
  224. return "column";
  225. case kSparseBsr:
  226. return "row block";
  227. case kSparseBsc:
  228. return "column block";
  229. default:
  230. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  231. return "";
  232. }
  233. }
  234. inline std::string plainDimName(Layout layout) {
  235. switch (layout) {
  236. case kSparseCsr:
  237. return "column";
  238. case kSparseCsc:
  239. return "row";
  240. case kSparseBsr:
  241. return "column block";
  242. case kSparseBsc:
  243. return "row block";
  244. default:
  245. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  246. return "";
  247. }
  248. }
  249. inline size_t rowDimension(Layout layout, IntArrayRef size) {
  250. return size.size() - (isCompressedRow(layout) ? 2 : 1);
  251. }
  252. inline size_t columnDimension(Layout layout, IntArrayRef size) {
  253. return size.size() - (isCompressedColumn(layout) ? 2 : 1);
  254. }
  255. inline size_t compressedDimension(
  256. Layout layout,
  257. IntArrayRef size,
  258. size_t dense_ndim = 0) {
  259. return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1);
  260. }
  261. inline size_t plainDimension(
  262. Layout layout,
  263. IntArrayRef size,
  264. size_t dense_ndim = 0) {
  265. return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2);
  266. }
  267. inline int64_t numBatchDimensions(Tensor const& self) {
  268. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  269. self.layout(),
  270. "numBatchDimensions",
  271. [&self] { return self.crow_indices().dim() - 1; },
  272. [&self] { return self.ccol_indices().dim() - 1; });
  273. }
  274. inline std::pair<Tensor, Tensor> getCompressedPlainIndices(Tensor const& self) {
  275. return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
  276. self.layout(),
  277. "getCompressedPlainIndices",
  278. [&self] {
  279. return std::make_pair(self.crow_indices(), self.col_indices());
  280. },
  281. [&self] {
  282. return std::make_pair(self.ccol_indices(), self.row_indices());
  283. });
  284. }
  285. inline ScalarType getIndexDtype(Tensor const& self) {
  286. switch (self.layout()) {
  287. case kSparseCsr:
  288. case kSparseBsr:
  289. return self.crow_indices().scalar_type();
  290. case kSparseCsc:
  291. case kSparseBsc:
  292. return self.ccol_indices().scalar_type();
  293. case kSparse:
  294. return self._indices().scalar_type();
  295. default:
  296. return ScalarType::Long;
  297. }
  298. }
  299. inline Layout flip_compressed_layout(Layout layout) {
  300. switch (layout) {
  301. case kSparseCsr:
  302. return kSparseCsc;
  303. case kSparseCsc:
  304. return kSparseCsr;
  305. case kSparseBsr:
  306. return kSparseBsc;
  307. case kSparseBsc:
  308. return kSparseBsr;
  309. default:
  310. TORCH_CHECK(false, "Not a sparse compressed layout:", layout);
  311. return kSparseCsr;
  312. }
  313. }
  314. inline DimVector getBlockSize(Tensor const& self) {
  315. int64_t n_batch = numBatchDimensions(self);
  316. return at::DimVector(self.values().sizes().slice(n_batch + 1, 2));
  317. }
  318. inline at::OptionalArray<at::SymInt> getSymIntBlockSize(Tensor const& self) {
  319. if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) {
  320. int64_t n_batch = numBatchDimensions(self);
  321. return self.values().sym_sizes().slice(n_batch + 1, 2).vec();
  322. } else {
  323. return {};
  324. }
  325. }
  326. template <typename binary_op_t, typename binary_op_out_t>
  327. inline bool only_sparse_compressed_binary_op_trivial_cases(
  328. const Tensor& self,
  329. const Tensor& other,
  330. const Scalar& alpha,
  331. Tensor& out,
  332. const binary_op_t& binary_op,
  333. const binary_op_out_t& binary_op_out) {
  334. // Only sparse compressed! Just like the name says :)
  335. TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self));
  336. TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other));
  337. TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out));
  338. // Bypass BLAS if there are matches in (self, other, out)
  339. if (self.is_same(out) && self.is_same(other)) {
  340. binary_op_out(self.values(), other.values(), alpha);
  341. return true;
  342. }
  343. if (self.is_same(other)) {
  344. auto [compressed_indices, plain_indices] =
  345. at::sparse_csr::getCompressedPlainIndices(self);
  346. static_cast<SparseCsrTensorImpl*>(out.unsafeGetTensorImpl())
  347. ->set_member_tensors(
  348. compressed_indices,
  349. plain_indices,
  350. binary_op(self.values(), other.values(), alpha),
  351. self.sizes());
  352. return true;
  353. }
  354. return false;
  355. }
  356. inline bool only_sparse_compressed_add_trivial_cases(
  357. const Tensor& self,
  358. const Tensor& other,
  359. const Scalar& alpha,
  360. Tensor& out) {
  361. return only_sparse_compressed_binary_op_trivial_cases(
  362. self,
  363. other,
  364. alpha,
  365. out,
  366. [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
  367. return v1.add(v2, alpha);
  368. },
  369. [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) {
  370. return v1.add_(v2, alpha);
  371. });
  372. }
  373. inline Tensor to_type(const Tensor& input, ScalarType dtype) {
  374. auto [compressed_indices, plain_indices] =
  375. at::sparse_csr::getCompressedPlainIndices(input);
  376. return at::_sparse_compressed_tensor_unsafe(
  377. compressed_indices,
  378. plain_indices,
  379. std::move(input.values()).to(dtype),
  380. input.sizes(),
  381. dtype,
  382. input.layout(),
  383. input.device(),
  384. input.options().pinned_memory_opt());
  385. }
  386. template <typename acc_t, typename scalar_t>
  387. inline std::tuple<Tensor, Tensor> create_acc_buffer(
  388. TensorOptions option,
  389. ScalarType type,
  390. int64_t nnz = -1) {
  391. Tensor new_values, new_values_acc;
  392. constexpr bool need_acc = !std::is_same_v<scalar_t, acc_t>;
  393. bool is_integral = at::isIntegralType(type, /*includeBool=*/true);
  394. if constexpr (need_acc) {
  395. auto acc_dtype = CppTypeToScalarType<acc_t>::value;
  396. new_values_acc = at::empty({}, option.dtype(acc_dtype));
  397. new_values = is_integral ? new_values_acc : at::empty({}, option);
  398. } else {
  399. new_values = new_values_acc = at::empty({}, option);
  400. }
  401. if (nnz != -1) {
  402. return std::make_tuple(
  403. new_values.resize_(nnz), new_values_acc.resize_(nnz));
  404. } else {
  405. return std::make_tuple(new_values, new_values_acc);
  406. }
  407. }
  408. inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) {
  409. if (!new_values_acc.is_same(new_values)) {
  410. new_values.copy_(new_values_acc);
  411. }
  412. }
  413. } // namespace at::sparse_csr
  414. #else
  415. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  416. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)