NestedTensorImpl.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/MemoryOverlap.h>
  4. #include <ATen/Tensor.h>
  5. #include <c10/core/DispatchKey.h>
  6. #include <c10/core/DispatchKeySet.h>
  7. #include <c10/core/MemoryFormat.h>
  8. #include <c10/core/TensorImpl.h>
  9. #include <c10/util/ArrayRef.h>
  10. #include <c10/util/Exception.h>
  11. #include <c10/util/Metaprogramming.h>
  12. #include <c10/util/irange.h>
  13. namespace at::native {
  14. struct NestedTensorImpl;
  15. inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt);
  16. int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor);
  17. at::Tensor construct_nested_strides(const at::Tensor& nested_size);
  18. at::Tensor construct_offsets(const at::Tensor& nested_size);
  19. struct TORCH_API NestedTensorImpl : public c10::TensorImpl {
  20. explicit NestedTensorImpl(
  21. Storage storage,
  22. c10::DispatchKeySet key_set,
  23. const caffe2::TypeMeta data_type,
  24. at::Tensor nested_sizes,
  25. at::Tensor nested_strides,
  26. at::Tensor storage_offsets);
  27. explicit NestedTensorImpl(
  28. const at::Tensor& buffer,
  29. at::Tensor nested_sizes,
  30. at::Tensor nested_strides,
  31. at::Tensor storage_offsets);
  32. // assume contiguous, `nested_strides` and `offsets`
  33. // can be inferred from `nested_sizes`
  34. explicit NestedTensorImpl(
  35. const at::Tensor& buffer,
  36. const at::Tensor& nested_sizes);
  37. // This constructor is used creating view tensors from nested tensors
  38. explicit NestedTensorImpl(
  39. c10::TensorImpl::ImplType impl_type,
  40. const at::Tensor& base_tensor,
  41. at::Tensor nested_sizes,
  42. at::Tensor nested_strides,
  43. at::Tensor storage_offsets);
  44. // TODO: don't expose private implementation details like this; in
  45. // particular, resizing this tensor will mess up our dim() and
  46. // callers cannot fix it.
  47. const Tensor& get_nested_sizes() const {
  48. return nested_sizes_;
  49. }
  50. // TODO: don't expose private implementation details like this
  51. const Tensor& get_nested_strides() const {
  52. return nested_strides_;
  53. }
  54. const Tensor& get_storage_offsets() const {
  55. return storage_offsets_;
  56. }
  57. // Returns nullopt if the ith dimension is irregular. The ith dimension
  58. // of a NestedTensor is regular if the unbound tensors match in
  59. // size at the (i-1)th dimension.
  60. std::optional<int64_t> opt_size(int64_t d) const;
  61. int64_t size(int64_t d) const {
  62. std::optional<int64_t> optional_size = this->opt_size(d);
  63. TORCH_CHECK(
  64. optional_size.has_value(),
  65. "Given dimension ",
  66. d,
  67. " is irregular and does not have a size.");
  68. return *optional_size;
  69. }
  70. /**
  71. * Return a view of the nested tensor as a 1 dimensional contiguous tensor.
  72. *
  73. * The buffer tensor created by this function shares the same storage_impl as
  74. * the original nested tensor, and therefore can be seen as a view.
  75. *
  76. * @return A newly constructed view tensor
  77. */
  78. at::Tensor get_buffer() const {
  79. TORCH_CHECK(
  80. nested_tensor_impl_is_contiguous(this),
  81. "NestedTensor must be contiguous to get buffer.");
  82. return get_unsafe_storage_as_tensor();
  83. }
  84. /**
  85. * If possible use get_buffer() instead. This function returns the storage
  86. * as a tensor directly, which is not safe to use in general. If using this
  87. * function, The caller must ensure to account for nested_sizes,
  88. * nested_strides and storage_offsets.
  89. *
  90. * @return A newly constructed view tensor
  91. */
  92. at::Tensor get_unsafe_storage_as_tensor() const {
  93. auto buffer_key_set_ = generate_buffer_key_set();
  94. const auto buffer_size = get_buffer_size();
  95. auto buffer_tensor_impl = c10::make_intrusive<TensorImpl>(
  96. c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_);
  97. buffer_tensor_impl->set_sizes_contiguous(
  98. c10::makeArrayRef(static_cast<int64_t>(buffer_size)));
  99. return Tensor(buffer_tensor_impl);
  100. }
  101. size_t get_buffer_size() const {
  102. return storage_.nbytes() / data_type_.itemsize();
  103. }
  104. protected:
  105. const char* tensorimpl_type_name() const override;
  106. // TODO: numel_custom and is_contiguous_custom can be profitably overridden
  107. // with real implementations
  108. int64_t numel_custom() const override;
  109. c10::SymInt sym_numel_custom() const override;
  110. c10::SymBool sym_is_contiguous_custom(
  111. MemoryFormat /*memory_format*/) const override;
  112. int64_t size_custom(int64_t d) const override {
  113. return this->size(d);
  114. }
  115. c10::SymInt sym_size_custom(int64_t d) const override {
  116. return c10::SymInt{this->size(d)};
  117. }
  118. IntArrayRef sizes_custom() const override;
  119. c10::SymIntArrayRef sym_sizes_custom() const override;
  120. IntArrayRef strides_custom() const override;
  121. c10::SymIntArrayRef sym_strides_custom() const override;
  122. // this one is real
  123. int64_t dim_custom() const override;
  124. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  125. const c10::VariableVersion& version_counter,
  126. bool allow_tensor_metadata_change) const override;
  127. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  128. c10::VariableVersion&& version_counter,
  129. bool allow_tensor_metadata_change) const override;
  130. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  131. copy_tensor_metadata(
  132. /*src_impl=*/impl.get(),
  133. /*dest_impl=*/this,
  134. /*version_counter=*/version_counter(),
  135. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  136. }
  137. private:
  138. // Must be called after any changes to our dim() to sync the state
  139. // to TensorImpl.
  140. void refresh_dim();
  141. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  142. const at::Tensor nested_sizes_, nested_strides_;
  143. // The starting positions of the underlying tensors in contiguous buffer
  144. // i.e. the buffer memory offsets to get the underlying tensors
  145. // The reason to keep this metadata is that, without strong enough constraint
  146. // it cannot be derived from `nested_sizes_`
  147. // and `nested_strides_`:
  148. // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2]
  149. // this can happen e.g. after slicing a nested tensor
  150. // 2. when multiple tensors share a same memory
  151. // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2]
  152. // Some strong enough constraints are:
  153. // 1. every underlying tensor is contiguous in memory
  154. // && nesting in ascending order
  155. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  156. const at::Tensor storage_offsets_;
  157. // NOTE: -1 here means the size is missing
  158. // Optional to allow it to be computed lazily from nested.
  159. // TODO: maybe we can remove this metadata since
  160. // we can compute it from `nested_sizes_`
  161. mutable std::optional<std::vector<int64_t>> opt_sizes_;
  162. template <typename VariableVersion>
  163. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
  164. VariableVersion&& version_counter,
  165. bool allow_tensor_metadata_change) const;
  166. /**
  167. * Generates a non-nested key_set from a nested tensor.
  168. *
  169. * For many nested tensor kernel implementations a buffer tensor
  170. * is generated and redispatched to a non-nested kernel this function
  171. * generates the key set used by that buffer tensor
  172. *
  173. * @return Appropriate key set for non-nested tensor
  174. */
  175. inline c10::DispatchKeySet generate_buffer_key_set() const {
  176. auto buffer_key_set = this->key_set();
  177. const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset);
  178. // Remove nested tensor specific keys
  179. buffer_key_set = buffer_key_set -
  180. c10::DispatchKeySet{
  181. c10::DispatchKey::NestedTensor,
  182. c10::DispatchKey::AutogradNestedTensor};
  183. // Add dense tensor specific keys
  184. buffer_key_set =
  185. buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense};
  186. buffer_key_set = Autograd
  187. ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set
  188. : buffer_key_set;
  189. return buffer_key_set;
  190. }
  191. };
  192. inline NestedTensorImpl* get_nested_tensor_impl_or_null(
  193. const at::Tensor& tensor) {
  194. if (tensor.is_nested()) {
  195. return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
  196. }
  197. return nullptr;
  198. }
  199. inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) {
  200. TORCH_CHECK(
  201. tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor.");
  202. return static_cast<NestedTensorImpl*>(tensor.unsafeGetTensorImpl());
  203. }
  204. inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) {
  205. int64_t ntensors = nt->size(0);
  206. if (ntensors == 0) {
  207. return true;
  208. }
  209. const Tensor &sizemat = nt->get_nested_sizes(),
  210. &stridemat = nt->get_nested_strides();
  211. const int64_t* offsets_ptr =
  212. nt->get_storage_offsets().const_data_ptr<int64_t>();
  213. int64_t orig_dim = sizemat.size(1);
  214. // nesting scalars
  215. if (orig_dim == 0) {
  216. // each scalar must be contiguous
  217. // if there is blank memory between underlying scalars
  218. for (int64_t i = 0; i < ntensors; i++) {
  219. if (offsets_ptr[i] != i) {
  220. return false;
  221. }
  222. }
  223. }
  224. // nesting tensors
  225. else {
  226. // if any underlying tensor is non-contiguous
  227. const int64_t *sizemat_ptr = sizemat.const_data_ptr<int64_t>(),
  228. *stridemat_ptr = stridemat.const_data_ptr<int64_t>();
  229. for (int64_t i = 0; i < ntensors; i++) {
  230. if (stridemat_ptr[orig_dim - 1] != 1) {
  231. return false;
  232. }
  233. int64_t product = sizemat_ptr[orig_dim - 1];
  234. for (int64_t j = orig_dim - 2; j >= 0; j--) {
  235. if (stridemat_ptr[j] != product) {
  236. return false;
  237. }
  238. product *= sizemat_ptr[j];
  239. }
  240. sizemat_ptr += orig_dim;
  241. stridemat_ptr += orig_dim;
  242. }
  243. // if there is blank memory between underlying tensors
  244. if (offsets_ptr[0] != 0) {
  245. return false;
  246. }
  247. sizemat_ptr = sizemat.const_data_ptr<int64_t>();
  248. stridemat_ptr = stridemat.const_data_ptr<int64_t>();
  249. for (int64_t i = 1; i < ntensors; i++) {
  250. if (offsets_ptr[i] !=
  251. offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) {
  252. return false;
  253. }
  254. sizemat_ptr += orig_dim;
  255. stridemat_ptr += orig_dim;
  256. }
  257. }
  258. // everything is fine
  259. return true;
  260. }
  261. inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) {
  262. return get_nested_tensor_impl(tensor)->get_nested_sizes();
  263. }
  264. } // namespace at::native
  265. #else
  266. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  267. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)