SparseCsrTensorImpl.h 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/Tensor.h>
  4. #include <c10/core/TensorImpl.h>
  5. #include <c10/core/impl/TorchDispatchModeTLS.h>
  6. #include <c10/util/Exception.h>
  7. namespace at {
  8. // Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
  9. // denoting the data: `crow_indices_`, `col_indices_` and `values_`.
  10. // The `crow_indices_` tensor is an integer tensor of shape `(size(0) + 1)`
  11. // that represents the compressed row indices of the CSR tensor. The
  12. // `col_indices_` tensor is an integer tensor of shape `(nnz())`
  13. // that explicitly stores the column indices of each value of the sparse
  14. // tensor. The `values_` tensor can be of any pytorch-supported data type
  15. // and has shape `(nnz())`.
  16. //
  17. // Since the main advantage of the CSR format over the COO format is speed of
  18. // computation, care must be taken to facilitate smooth interfacing of
  19. // these data structures with optimized libraries such as MKL and MAGMA.
  20. // Since the MKL interface for pytorch currently uses indexing with int32
  21. // type, it is important to make sure that the `crow_indices` and `col_indices`
  22. // are of type int32 when calling MKL routines such as SPMM or SPMV.
  23. //
  24. // If not calling MKL, it should be alright to use 64 bit integer tensors
  25. // for indexing.
  26. struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
  27. Tensor crow_indices_;
  28. Tensor col_indices_;
  29. Tensor values_;
  30. Layout layout_;
  31. public:
  32. explicit SparseCsrTensorImpl(
  33. at::DispatchKeySet /*key_set*/,
  34. at::Device device,
  35. Layout layout,
  36. const caffe2::TypeMeta /*data_type*/);
  37. void resize_(int64_t nnz, IntArrayRef size);
  38. void resize_and_clear_(
  39. int64_t sparse_dim,
  40. int64_t dense_dim,
  41. IntArrayRef size);
  42. void resize_as_sparse_compressed_tensor_(const Tensor& src);
  43. void set_member_tensors(
  44. const Tensor& crow_indices,
  45. const Tensor& col_indices,
  46. const Tensor& values,
  47. c10::SymIntArrayRef size);
  48. void set_member_tensors(
  49. const Tensor& crow_indices,
  50. const Tensor& col_indices,
  51. const Tensor& values,
  52. IntArrayRef size);
  53. const Tensor& compressed_indices() const {
  54. return crow_indices_;
  55. }
  56. const Tensor& plain_indices() const {
  57. return col_indices_;
  58. }
  59. const Tensor& values() const {
  60. return values_;
  61. }
  62. int64_t nnz() {
  63. return col_indices_.size(-1);
  64. }
  65. inline int64_t batch_dim() const noexcept {
  66. return crow_indices_.dim() - 1;
  67. }
  68. inline int64_t sparse_dim() const noexcept {
  69. return 2;
  70. }
  71. inline int64_t dense_dim() const noexcept {
  72. return values_.dim() - batch_dim() - block_dim() - 1;
  73. }
  74. private:
  75. inline int64_t block_dim() const noexcept {
  76. return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0);
  77. }
  78. protected:
  79. IntArrayRef strides_custom() const override;
  80. SymIntArrayRef sym_strides_custom() const override;
  81. SymBool sym_is_contiguous_custom(
  82. MemoryFormat /*memory_format*/) const override;
  83. public:
  84. void set_size(int64_t dim, int64_t new_size) override;
  85. void set_stride(int64_t dim, int64_t new_stride) override;
  86. void set_storage_offset(int64_t storage_offset) override;
  87. Layout layout_impl() const override {
  88. return layout_;
  89. }
  90. void set_layout(Layout layout) {
  91. switch (layout) {
  92. case kSparseCsr:
  93. case kSparseCsc:
  94. case kSparseBsr:
  95. case kSparseBsc:
  96. layout_ = layout;
  97. break;
  98. default:
  99. TORCH_CHECK(false, "unsupported layout ", layout);
  100. }
  101. }
  102. template <typename VariableVersion>
  103. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
  104. VariableVersion&& version_counter,
  105. bool allow_tensor_metadata_change) const {
  106. const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
  107. c10::impl::PyInterpreter&& interpreter = nullptr;
  108. if (mode_stack_len > 0 &&
  109. !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
  110. const auto& cur_torch_dispatch_mode_state =
  111. c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
  112. interpreter = cur_torch_dispatch_mode_state->pyinterpreter();
  113. } else if (
  114. key_set_.has(DispatchKey::Python) &&
  115. !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
  116. interpreter = pyobj_slot_.load_pyobj_interpreter();
  117. } else {
  118. // otherwise just copy the SparseTensorImpl and not the PyObject.
  119. auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
  120. key_set(), device(), layout_impl(), dtype());
  121. copy_tensor_metadata(
  122. /*src_sparse_impl=*/this,
  123. /*dest_sparse_impl=*/impl.get(),
  124. /*version_counter=*/version_counter,
  125. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  126. impl->refresh_numel();
  127. return impl;
  128. }
  129. auto r = interpreter->detach(this);
  130. r->set_version_counter(std::forward<VariableVersion>(version_counter));
  131. r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  132. return r;
  133. }
  134. /**
  135. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  136. *
  137. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  138. * see NOTE [ TensorImpl Shallow-Copying ].
  139. */
  140. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  141. const c10::VariableVersion& version_counter,
  142. bool allow_tensor_metadata_change) const override {
  143. return shallow_copy_and_detach_core(
  144. version_counter, allow_tensor_metadata_change);
  145. }
  146. /**
  147. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  148. *
  149. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  150. * see NOTE [ TensorImpl Shallow-Copying ].
  151. */
  152. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  153. c10::VariableVersion&& version_counter,
  154. bool allow_tensor_metadata_change) const override {
  155. return shallow_copy_and_detach_core(
  156. std::move(version_counter), allow_tensor_metadata_change);
  157. }
  158. private:
  159. explicit SparseCsrTensorImpl(
  160. at::DispatchKeySet key_set,
  161. const caffe2::TypeMeta data_type,
  162. at::Tensor crow_indices,
  163. at::Tensor col_indices,
  164. at::Tensor values,
  165. at::Layout layout);
  166. const char* tensorimpl_type_name() const override;
  167. /**
  168. * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
  169. * storage_offset) from one TensorImpl to another TensorImpl.
  170. *
  171. * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
  172. * [ TensorImpl Shallow-Copying ].
  173. */
  174. static void copy_tensor_metadata(
  175. const SparseCsrTensorImpl* src_sparse_impl,
  176. SparseCsrTensorImpl* dest_sparse_impl,
  177. c10::VariableVersion version_counter,
  178. bool allow_tensor_metadata_change) {
  179. TensorImpl::copy_tensor_metadata(
  180. src_sparse_impl,
  181. dest_sparse_impl,
  182. std::move(version_counter),
  183. allow_tensor_metadata_change);
  184. // Sparse-specific fields
  185. dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices();
  186. dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices();
  187. dest_sparse_impl->values_ = src_sparse_impl->values();
  188. dest_sparse_impl->layout_ = src_sparse_impl->layout_impl();
  189. }
  190. };
  191. } // namespace at
  192. #else
  193. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  194. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)