SparseTensorImpl.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/Tensor.h>
  4. #include <c10/core/TensorImpl.h>
  5. #include <c10/core/impl/TorchDispatchModeTLS.h>
  6. #include <c10/util/Exception.h>
  7. #include <c10/util/irange.h>
  8. #ifndef AT_PER_OPERATOR_HEADERS
  9. #include <ATen/Functions.h>
  10. #else
  11. #include <ATen/ops/empty.h>
  12. #include <ATen/ops/resize.h>
  13. #endif
  14. namespace at {
  15. struct TORCH_API SparseTensorImpl : public TensorImpl {
  // Stored in COO format, indices + values.

  // INVARIANTS:
  // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape)
  // _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz)
  // _values.shape: dimensionality: 1 + dense_dim. shape: (nnz,
  // shape[sparse_dim:])
  int64_t sparse_dim_ = 0; // number of sparse dimensions
  int64_t dense_dim_ = 0; // number of dense dimensions
  Tensor indices_; // always a LongTensor
  Tensor values_; // leading dimension is nnz (this is what nnz() reads)
  // A sparse tensor is 'coalesced' if every index occurs at most once in
  // the indices tensor, and the indices are in sorted order. (This means
  // that it is very easy to convert a coalesced tensor to CSR format: you
  // need only compute CSR format indices.)
  //
  // Most math operations can only be performed on coalesced sparse tensors,
  // because many algorithms proceed by merging two sorted lists (of indices).
  //
  // Defaults to false, i.e. a freshly constructed impl is not assumed to be
  // coalesced until something sets the flag.
  bool coalesced_ = false;
  35. // compute_numel with integer multiplication overflow check, see gh-57542
  36. void refresh_numel() {
  37. TensorImpl::safe_refresh_numel();
  38. }
  39. public:
  // Public for now...
  //
  // Constructs a sparse impl from a dispatch key set and an element type.
  // Only declared here; the definition is not part of this header.
  explicit SparseTensorImpl(
      at::DispatchKeySet /*key_set*/,
      const caffe2::TypeMeta /*data_type*/);

  // Override of the base-class resource-release hook. Only declared here; the
  // definition is not part of this header.
  void release_resources() override;
  45. int64_t nnz() const {
  46. return values_.size(0);
  47. }
  48. c10::SymInt sym_nnz() const {
  49. return values_.sym_size(0);
  50. }
  51. int64_t sparse_dim() const {
  52. return sparse_dim_;
  53. }
  54. int64_t dense_dim() const {
  55. return dense_dim_;
  56. }
  57. bool coalesced() const {
  58. return coalesced_;
  59. }
  60. Tensor indices() const {
  61. return indices_;
  62. }
  63. Tensor values() const {
  64. return values_;
  65. }
  // Overrides of the base-class per-dimension metadata setters. They are only
  // declared here; their definitions are not part of this header.
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  void set_storage_offset(int64_t storage_offset) override;

#ifdef DEBUG
  // This override is only declared when DEBUG is defined.
  bool has_storage() const override;
#endif
  72. // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim
  73. // with respect to indices and values
  74. void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) {
  75. TORCH_CHECK(
  76. allow_tensor_metadata_change(),
  77. "raw_resize_ ",
  78. err_msg_tensor_metadata_change_not_allowed);
  79. TORCH_CHECK(
  80. !has_symbolic_sizes_strides_,
  81. "raw_resize_ called on tensor with symbolic shape")
  82. set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
  83. sparse_dim_ = sparse_dim;
  84. dense_dim_ = dense_dim;
  85. refresh_numel();
  86. }
  87. // NOTE: This function preserves invariants of sparse_dim/dense_dim with
  88. // respect to indices and values.
  89. //
  90. // NOTE: This function supports the following cases:
  91. // 1. When we keep the number of dense dimensions unchanged, and NOT shrinking
  92. // the size of any of the dense dimensions.
  93. // 2. When we keep the number of sparse dimensions unchanged, and NOT
  94. // shrinking the size of any of the sparse dimensions.
  95. // 3. When the sparse tensor has zero nnz, in which case we are free to change
  96. // the shapes of both its sparse and dense dimensions.
  97. //
  98. // This function DOESN'T support (and will throw an error) the following
  99. // cases:
  100. // 1. When we attempt to change the number of sparse dimensions on a non-empty
  101. // sparse tensor (such an operation will invalidate the indices stored).
  102. // 2. When we attempt to change the number of dense dimensions on a non-empty
  103. // sparse tensor (such an operation will behave differently from an equivalent
  104. // dense tensor's resize method, and for API consistency we don't support it).
  105. // 3. When we attempt to shrink the size of any of the dense dimensions on a
  106. // non-empty sparse tensor (such an operation will behave differently from an
  107. // equivalent dense tensor's resize method, and for API consistency we don't
  108. // support it).
  109. // 4. When we attempt to shrink the size of any of the sparse dimensions on a
  110. // non-empty sparse tensor (this could make some of the stored indices
  111. // out-of-bound and thus unsafe).
  112. template <typename T>
  113. void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<T> size) {
  114. TORCH_CHECK(
  115. allow_tensor_metadata_change(),
  116. "resize_ ",
  117. err_msg_tensor_metadata_change_not_allowed);
  118. TORCH_CHECK(
  119. !has_symbolic_sizes_strides_,
  120. "resize_ called on tensor with symbolic shape")
  121. TORCH_CHECK(
  122. sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
  123. "'len(size) == sparse_dim + dense_dim' is not satisfied: len(size) = ",
  124. size.size(),
  125. ", sparse_dim = ",
  126. sparse_dim,
  127. ", dense_dim = ",
  128. dense_dim);
  129. if (nnz() > 0) {
  130. [[maybe_unused]] auto constexpr alt_options_msg =
  131. "You could try the following options:\n\
  132. 1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\
  133. 2. If you need to resize this tensor, you have the following options:\n\
  134. 1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\
  135. 2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor.";
  136. TORCH_CHECK(
  137. sparse_dim == sparse_dim_,
  138. "changing the number of sparse dimensions (from ",
  139. sparse_dim_,
  140. " to ",
  141. sparse_dim,
  142. ") on a non-empty sparse tensor is not supported.\n",
  143. alt_options_msg);
  144. TORCH_CHECK(
  145. dense_dim == dense_dim_,
  146. "changing the number of dense dimensions (from ",
  147. dense_dim_,
  148. " to ",
  149. dense_dim,
  150. ") on a non-empty sparse tensor is not supported.\n",
  151. alt_options_msg);
  152. bool shrinking_sparse_dims = false;
  153. bool shrinking_dense_dim = false;
  154. auto sparse_size_original = generic_sizes<T>().slice(0, sparse_dim);
  155. auto sparse_size_new = size.slice(0, sparse_dim);
  156. for (const auto i : c10::irange(sparse_dim)) {
  157. if (sparse_size_new[i] < sparse_size_original[i]) {
  158. shrinking_sparse_dims = true;
  159. break;
  160. }
  161. }
  162. auto dense_size_original = generic_sizes<T>().slice(sparse_dim);
  163. auto dense_size_new = size.slice(sparse_dim);
  164. for (const auto i : c10::irange(dense_dim)) {
  165. if (dense_size_new[i] < dense_size_original[i]) {
  166. shrinking_dense_dim = true;
  167. break;
  168. }
  169. }
  170. TORCH_CHECK(
  171. !shrinking_sparse_dims,
  172. "shrinking the size of sparse dimensions (from ",
  173. sparse_size_original,
  174. " to ",
  175. sparse_size_new,
  176. ") on a non-empty sparse tensor is not supported.\n",
  177. alt_options_msg);
  178. TORCH_CHECK(
  179. !shrinking_dense_dim,
  180. "shrinking the size of dense dimensions (from ",
  181. dense_size_original,
  182. " to ",
  183. dense_size_new,
  184. ") on a non-empty sparse tensor is not supported.\n",
  185. alt_options_msg);
  186. }
  187. auto sizes_and_strides = generic_sizes<T>();
  188. const bool size_equals_sizes = std::equal(
  189. size.begin(),
  190. size.end(),
  191. sizes_and_strides.begin(),
  192. sizes_and_strides.end());
  193. if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) ||
  194. (dense_dim != dense_dim_)) {
  195. auto nnz = at::symint::sizes<T>(values())[0];
  196. std::vector<T> values_size = {nnz};
  197. auto dense_size = size.slice(sparse_dim);
  198. values_size.insert(
  199. values_size.end(), dense_size.begin(), dense_size.end());
  200. at::symint::resize_<T>(values_, values_size);
  201. at::symint::resize_<T>(indices_, {T(sparse_dim), nnz});
  202. }
  203. if (!size_equals_sizes) {
  204. set_sizes_and_strides(size, std::vector<T>(size.size()));
  205. }
  206. sparse_dim_ = sparse_dim;
  207. dense_dim_ = dense_dim;
  208. refresh_numel();
  209. }
  210. void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size) {
  211. _resize_(sparse_dim, dense_dim, size);
  212. }
  213. void resize_(
  214. int64_t sparse_dim,
  215. int64_t dense_dim,
  216. ArrayRef<c10::SymInt> size) {
  217. _resize_(sparse_dim, dense_dim, size);
  218. }
  219. // NOTE: this function will resize the sparse tensor and also set `indices`
  220. // and `values` to empty.
  221. void resize_and_clear_(
  222. int64_t sparse_dim,
  223. int64_t dense_dim,
  224. IntArrayRef size) {
  225. TORCH_CHECK(
  226. allow_tensor_metadata_change(),
  227. "resize_and_clear_ ",
  228. err_msg_tensor_metadata_change_not_allowed);
  229. TORCH_CHECK(
  230. !has_symbolic_sizes_strides_,
  231. "resize_and_clear_ called on tensor with symbolic shape")
  232. TORCH_CHECK(
  233. sparse_dim + dense_dim == static_cast<int64_t>(size.size()),
  234. "'len(size) == sparse_dim + dense_dim' is not satisfied: len(size) = ",
  235. size.size(),
  236. ", sparse_dim = ",
  237. sparse_dim,
  238. ", dense_dim = ",
  239. dense_dim);
  240. set_sizes_and_strides(size, std::vector<int64_t>(size.size()));
  241. sparse_dim_ = sparse_dim;
  242. dense_dim_ = dense_dim;
  243. auto empty_indices = at::empty({sparse_dim, 0}, indices().options());
  244. std::vector<int64_t> values_size = {0};
  245. auto dense_size = sizes().slice(sparse_dim);
  246. values_size.insert(values_size.end(), dense_size.begin(), dense_size.end());
  247. auto empty_values = at::empty(values_size, values().options());
  248. set_indices_and_values_unsafe(empty_indices, empty_values);
  249. refresh_numel();
  250. }
  251. void set_coalesced(bool coalesced) {
  252. TORCH_CHECK(
  253. allow_tensor_metadata_change(),
  254. "set_coalesced ",
  255. err_msg_tensor_metadata_change_not_allowed);
  256. coalesced_ = coalesced;
  257. }
  258. // NOTE: this function is only used internally and not exposed to Python
  259. // frontend
  260. void set_nnz_and_narrow(int64_t new_nnz) {
  261. TORCH_CHECK(
  262. allow_tensor_metadata_change(),
  263. "set_nnz_and_narrow ",
  264. err_msg_tensor_metadata_change_not_allowed);
  265. AT_ASSERT(new_nnz <= nnz());
  266. indices_ = indices_.narrow(1, 0, new_nnz);
  267. values_ = values_.narrow(0, 0, new_nnz);
  268. if (new_nnz < 2) {
  269. coalesced_ = true;
  270. }
  271. }
  // Takes indices and values and directly puts them into the sparse tensor, no
  // copy.
  //
  // NOTE: this function is unsafe because it doesn't check whether any
  // indices are out of boundaries of `sizes`, so it should ONLY be used where
  // we know that the indices are guaranteed to be within bounds.
  //
  // This used to be called THSTensor_(_move).
  // NB: This used to be able to avoid a refcount bump, but I was too lazy to
  // make it happen.
  //
  // Only declared here; the definition is not part of this header.
  void set_indices_and_values_unsafe(
      const Tensor& indices,
      const Tensor& values);
  281. template <typename VariableVersion>
  282. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
  283. VariableVersion&& version_counter,
  284. bool allow_tensor_metadata_change) const {
  285. const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
  286. c10::impl::PyInterpreter&& interpreter = nullptr;
  287. if (mode_stack_len > 0 &&
  288. !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
  289. const auto& cur_torch_dispatch_mode_state =
  290. c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
  291. interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
  292. } else if (
  293. key_set_.has(DispatchKey::Python) &&
  294. !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
  295. interpreter = pyobj_slot_.load_pyobj_interpreter();
  296. } else {
  297. // otherwise just copy the SparseTensorImpl and not the PyObject.
  298. auto impl = c10::make_intrusive<SparseTensorImpl>(key_set(), dtype());
  299. copy_tensor_metadata(
  300. /*src_sparse_impl=*/this,
  301. /*dest_sparse_impl=*/impl.get(),
  302. /*version_counter=*/version_counter,
  303. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  304. impl->refresh_numel();
  305. return impl;
  306. }
  307. auto r = interpreter->detach(this);
  308. r->set_version_counter(std::forward<VariableVersion>(version_counter));
  309. r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  310. return r;
  311. }
  312. /**
  313. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  314. *
  315. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  316. * see NOTE [ TensorImpl Shallow-Copying ].
  317. */
  318. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  319. const c10::VariableVersion& version_counter,
  320. bool allow_tensor_metadata_change) const override {
  321. return shallow_copy_and_detach_core(
  322. version_counter, allow_tensor_metadata_change);
  323. }
  324. /**
  325. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  326. *
  327. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  328. * see NOTE [ TensorImpl Shallow-Copying ].
  329. */
  330. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  331. c10::VariableVersion&& version_counter,
  332. bool allow_tensor_metadata_change) const override {
  333. return shallow_copy_and_detach_core(
  334. std::move(version_counter), allow_tensor_metadata_change);
  335. }
  336. /**
  337. * Shallow-copies data from another TensorImpl into this TensorImpl.
  338. *
  339. * For why this function doesn't check this TensorImpl's
  340. * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
  341. */
  342. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  343. AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  344. auto sparse_impl = static_cast<const SparseTensorImpl*>(impl.get());
  345. copy_tensor_metadata(
  346. /*src_sparse_impl=*/sparse_impl,
  347. /*dest_sparse_impl=*/this,
  348. /*version_counter=*/version_counter(),
  349. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  350. refresh_numel();
  351. }
  352. private:
  // Private constructor that additionally receives the indices and values
  // tensors. Only declared here; the definition is not part of this header.
  explicit SparseTensorImpl(
      at::DispatchKeySet /*key_set*/,
      const caffe2::TypeMeta /*data_type*/,
      at::Tensor indices,
      at::Tensor values);
  358. /**
  359. * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
  360. * storage_offset) from one TensorImpl to another TensorImpl.
  361. *
  362. * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
  363. * [ TensorImpl Shallow-Copying ].
  364. */
  365. static void copy_tensor_metadata(
  366. const SparseTensorImpl* src_sparse_impl,
  367. SparseTensorImpl* dest_sparse_impl,
  368. c10::VariableVersion version_counter,
  369. bool allow_tensor_metadata_change) {
  370. TensorImpl::copy_tensor_metadata(
  371. src_sparse_impl,
  372. dest_sparse_impl,
  373. std::move(version_counter),
  374. allow_tensor_metadata_change);
  375. // Sparse-specific fields
  376. dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim();
  377. dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim();
  378. dest_sparse_impl->indices_ = src_sparse_impl->indices();
  379. dest_sparse_impl->values_ = src_sparse_impl->values();
  380. dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced();
  381. }
  // Override of the base-class type-name hook. Only declared here; the
  // definition is not part of this header.
  const char* tensorimpl_type_name() const override;
  383. };
  384. } // namespace at
  385. #else
  386. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  387. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)