// OpaqueTensorImpl.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/MemoryFormat.h>
  4. #include <c10/core/SymIntArrayRef.h>
  5. #include <c10/core/TensorImpl.h>
  6. #include <c10/util/Exception.h>
  7. namespace at {
  8. // An "Opaque" TensorImpl -- there are no strides and (for now)
  9. // even data() is not supported (thus no pointer arithmetic).
  10. // NOTE: We could allow data() in the future, but would have to ensure pointer
  11. // arithmetic code is properly guarded.
  12. //
  13. // NOTE: This does not support resize_ (and other metadata-changing ops) because
  14. // of `shallow_copy_and_detach`. We would need to define an interface to
  15. // "shallow copy" in order to add support.
  16. template <typename OpaqueHandle>
  17. struct TORCH_API OpaqueTensorImpl : public TensorImpl {
  18. // public constructor for now...
  19. OpaqueTensorImpl(
  20. at::DispatchKeySet key_set,
  21. const caffe2::TypeMeta data_type,
  22. c10::Device device,
  23. OpaqueHandle opaque_handle,
  24. c10::IntArrayRef sizes,
  25. bool is_non_overlapping_and_dense = true)
  26. : TensorImpl(key_set, data_type, device),
  27. opaque_handle_(std::move(opaque_handle)) {
  28. constructor_impl(sizes, is_non_overlapping_and_dense);
  29. }
  30. OpaqueTensorImpl(
  31. TensorImpl::ImplType impl_type,
  32. c10::Storage&& storage,
  33. at::DispatchKeySet key_set,
  34. const caffe2::TypeMeta data_type,
  35. OpaqueHandle opaque_handle,
  36. c10::IntArrayRef sizes,
  37. bool is_non_overlapping_and_dense = true)
  38. : TensorImpl(impl_type, std::move(storage), key_set, data_type),
  39. opaque_handle_(std::move(opaque_handle)) {
  40. constructor_impl(sizes, is_non_overlapping_and_dense);
  41. }
  42. // Destructor doesn't call release_resources because it's
  43. // unnecessary; don't forget to change that if needed!
  44. void release_resources() override {
  45. TensorImpl::release_resources();
  46. opaque_handle_ = {};
  47. }
  48. void set_size(int64_t dim, int64_t new_size) override {
  49. TORCH_CHECK(false, "opaque tensors do not have set_size");
  50. }
  51. void set_stride(int64_t dim, int64_t new_stride) override {
  52. TORCH_CHECK(false, "opaque tensors do not have set_stride");
  53. }
  54. void set_storage_offset(int64_t storage_offset) override {
  55. TORCH_CHECK(false, "opaque tensors do not have set_storage_offset");
  56. }
  57. #ifdef DEBUG
  58. bool has_storage() const override {
  59. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
  60. !storage_, "OpaqueTensorImpl assumes that storage_ is never set");
  61. return false;
  62. }
  63. #endif
  64. /**
  65. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  66. *
  67. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  68. * see NOTE [ TensorImpl Shallow-Copying ].
  69. */
  70. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  71. const c10::VariableVersion& version_counter,
  72. bool allow_tensor_metadata_change) const override {
  73. auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
  74. key_set(),
  75. dtype(),
  76. device(),
  77. opaque_handle_,
  78. sizes_and_strides_.sizes_arrayref());
  79. copy_tensor_metadata(
  80. /*src_opaque_impl=*/this,
  81. /*dest_opaque_impl=*/impl.get(),
  82. /*version_counter=*/version_counter,
  83. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  84. impl->refresh_numel();
  85. return impl;
  86. }
  87. /**
  88. * Return a TensorImpl that is a shallow-copy of this TensorImpl.
  89. *
  90. * For usage of `version_counter` and `allow_tensor_metadata_change`,
  91. * see NOTE [ TensorImpl Shallow-Copying ].
  92. */
  93. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  94. c10::VariableVersion&& version_counter,
  95. bool allow_tensor_metadata_change) const override {
  96. auto impl = c10::make_intrusive<OpaqueTensorImpl<OpaqueHandle>>(
  97. key_set(),
  98. dtype(),
  99. device(),
  100. opaque_handle_,
  101. sizes_and_strides_.sizes_arrayref());
  102. copy_tensor_metadata(
  103. /*src_opaque_impl=*/this,
  104. /*dest_opaque_impl=*/impl.get(),
  105. /*version_counter=*/std::move(version_counter),
  106. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  107. impl->refresh_numel();
  108. return impl;
  109. }
  110. /**
  111. * Shallow-copies data from another TensorImpl into this TensorImpl.
  112. *
  113. * For why this function doesn't check this TensorImpl's
  114. * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
  115. */
  116. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override {
  117. AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
  118. auto opaque_impl =
  119. static_cast<const OpaqueTensorImpl<OpaqueHandle>*>(impl.get());
  120. copy_tensor_metadata(
  121. /*src_impl=*/opaque_impl,
  122. /*dest_impl=*/this,
  123. /*version_counter=*/version_counter(),
  124. /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
  125. refresh_numel();
  126. }
  127. const OpaqueHandle& opaque_handle() const {
  128. return opaque_handle_;
  129. }
  130. OpaqueHandle& unsafe_opaque_handle() {
  131. return opaque_handle_;
  132. }
  133. protected:
  134. /**
  135. * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
  136. * storage_offset) from one TensorImpl to another TensorImpl.
  137. *
  138. * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
  139. * [ TensorImpl Shallow-Copying ].
  140. */
  141. static void copy_tensor_metadata(
  142. const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
  143. OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
  144. const c10::VariableVersion& version_counter,
  145. bool allow_tensor_metadata_change) {
  146. TensorImpl::copy_tensor_metadata(
  147. src_opaque_impl,
  148. dest_opaque_impl,
  149. version_counter,
  150. allow_tensor_metadata_change);
  151. // OpaqueTensorImpl-specific fields.
  152. dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  153. }
  154. static void copy_tensor_metadata(
  155. const OpaqueTensorImpl<OpaqueHandle>* src_opaque_impl,
  156. OpaqueTensorImpl<OpaqueHandle>* dest_opaque_impl,
  157. c10::VariableVersion&& version_counter,
  158. bool allow_tensor_metadata_change) {
  159. TensorImpl::copy_tensor_metadata(
  160. src_opaque_impl,
  161. dest_opaque_impl,
  162. std::move(version_counter),
  163. allow_tensor_metadata_change);
  164. // OpaqueTensorImpl-specific fields.
  165. dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_;
  166. }
  167. private:
  168. const char* tensorimpl_type_name() const override {
  169. return "OpaqueTensorImpl";
  170. }
  171. void constructor_impl(
  172. c10::IntArrayRef sizes,
  173. bool is_non_overlapping_and_dense) {
  174. set_storage_access_should_throw();
  175. set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
  176. sizes_and_strides_.set_sizes(sizes);
  177. refresh_numel();
  178. // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
  179. is_non_overlapping_and_dense_ = is_non_overlapping_and_dense;
  180. }
  181. OpaqueHandle opaque_handle_;
  182. };
  183. } // namespace at
  184. #else
  185. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  186. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)