FunctionalTensorWrapper.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/ArrayRef.h>
  4. #include <ATen/FunctionalStorageImpl.h>
  5. #include <ATen/core/IListRef.h>
  6. #include <ATen/core/List.h>
  7. #include <ATen/core/boxing/BoxedKernel.h>
  8. #include <ATen/core/boxing/impl/boxing.h>
  9. #include <ATen/core/dispatch/Dispatcher.h>
  10. #include <c10/core/DispatchKey.h>
  11. namespace at {
  12. // Note [Functionalization Pass In Core]
  13. // The Functionalization pass is used to remove aliasing from a pytorch program.
  14. //
  15. // This is useful for backends that don't support aliasing, like XLA and Vulkan.
  16. // It's also necessary in order to remove mutation from a program, which is
  17. // needed in Functorch.
  18. //
  19. // Consider this program:
  20. // a = torch.ones(...)
  21. // b = a.view(...)
  22. // b.add_(1)
  23. //
  24. // In this program, b is meant to alias with a due to the use of view(). At the
  25. // end of the program, both a and b are full of 2's. However, backends that
  26. // don't support aliasing aren't able to correctly implement the view()
  27. // operator. Instead, they can opt into the Functionalization pass, which will
  28. // sit between the user and the backend, and provide the necessary aliasing
  29. // logic.
  30. //
  // The functionalization pass will turn the above program into a slightly
  // different program that has the same semantics, transparently to the user,
  // that backends like XLA/Vulkan are able to implement:
  //
  // a = torch.ones(...)
  // b = a.view_copy(...)  # view() replaced with view_copy(). Backends like
  //                       # XLA/Vulkan can implement this!
  // b.add_(1)
  // a.add_(1)  # Our functionalization pass machinery knows that a and b are
  //            # aliased - it applies b's mutation to a too.
  38. //
  39. // So, how does the functionalization pass keep track of which tensors are
  40. // aliased? The pass works by wrapping EVERY tensor in the program inside of a
  41. // FunctionalTensorWrapper, which knows about its alias'd tensors.
  42. //
  43. // See Note [Functionalization: Alias Removal] for details on the aliasing
  44. // machinery. See Note [Functionalization: Mutation Removal] for details on
  45. // mutation removal.
  46. struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
  47. explicit FunctionalTensorWrapper(const Tensor& value);
  48. // Additional constructor to create a FunctionalTensorWrapper directly from an
  49. // underlying tensor that was created from a view. For example, the code b =
  50. // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a,
  51. // view1_meta)
  52. explicit FunctionalTensorWrapper(
  53. const Tensor& view_value,
  54. const FunctionalTensorWrapper* base,
  55. const std::shared_ptr<functionalization::ViewMeta>& meta);
  56. // Get the underlying, actual tensor, that doesn't know anything about
  57. // functionalization.
  58. const Tensor& value() const {
  59. return value_;
  60. }
  61. // The concept of "level" is only ever important to functorch; it's exposed
  62. // here as more of a hook for functorch to use.
  63. int64_t level() const {
  64. return level_;
  65. }
  66. void set_level(int64_t level) {
  67. level_ = level;
  68. }
  69. bool has_metadata_mutation() const {
  70. return has_metadata_mutation_;
  71. }
  72. uint64_t mutation_counter() const {
  73. return functional_storage_impl()->mutation_counter();
  74. }
  75. void mark_mutation() {
  76. functional_storage_impl()->mark_mutation();
  77. }
  78. // Denotes a mutation that's hidden from autograd,
  79. // e.g. for the purposes of passing a tensor to a triton kernel
  80. void mark_mutation_hidden_from_autograd() {
  81. functional_storage_impl()->mark_mutation_hidden_from_autograd();
  82. }
  83. void mark_mutation_during_no_grad_or_inference_mode() {
  84. functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode();
  85. }
  86. // Are all the mutations happening to the tensor hidden from autograd
  87. bool are_all_mutations_hidden_from_autograd() const {
  88. return functional_storage_impl()->are_all_mutations_hidden_from_autograd();
  89. }
  90. // Did all mutations happen under no_grad or inference_mode
  91. // (We also need to ignore mutations fully hidden from autograd here)
  92. bool are_all_mutations_under_no_grad_or_inference_mode() const {
  93. return functional_storage_impl()
  94. ->are_all_mutations_under_no_grad_or_inference_mode();
  95. }
  96. void maybe_mark_symbolic(functionalization::ViewMeta* meta) {
  97. is_symbolic_ = is_symbolic_ | meta->has_symbolic_inputs;
  98. }
  99. bool is_symbolic() const {
  100. return is_symbolic_;
  101. }
  102. // Retrieves the ViewMeta sequence of this tensor.
  103. const std::vector<std::shared_ptr<functionalization::ViewMeta>>& view_metas()
  104. const;
  105. // Sync's the underlying tensor with its alias, if it's out of date. This
  106. // involves two steps: 1) Apply any pending updates/mutations to the alias 2)
  107. // Replay the views (if any) to regenerate the current tensor off of the
  108. // updated alias.
  109. void sync_();
  110. // Performs step (1) of the sync. This is its own public API because it's
  111. // needed by view_inplace ops like transpose_. See Note [Functionalization
  112. // Pass - Inplace View Ops]
  113. void regenerate_from_base();
  114. // Performs step (2) of the sync. This is its own public API because it's
  115. // needed by functorch. functorch wants to make sure that all input tensors to
  116. // a functionalized program have been properly synced so it can properly
  117. // propagate mutations to inputs. It can't just call sync_(), because the
  118. // FunctionalTensorWrapper will look like it has no aliases and sync_ will be
  119. // a noop. We use the reference count on storage_ to determine if the wrapper
  120. // is aliased, and by the time functorch is ready to propagate updates to
  121. // inputs, any intermediate views of the input created by the program will
  122. // have been deallocated. This function also returns whether or not the base
  123. // actually had any updates to apply.
  124. bool apply_updates();
  125. // Takes the current state of value_ and snapshots it, sending it as a pending
  126. // update to the alias.
  127. void commit_update();
  128. // When any tensor is mutated, the tensor increments its alias's "generation".
  129. // Separately, each tensor maintains its own "generation" counter, which is
  130. // used to determine if it's up-to-date with its alias. The act of syncing a
  131. // tensor will set a tensor's generation equal to its alias's generation.
  132. bool is_up_to_date() const;
  133. // Freezes the storage of this tensor, preventing subsequent mutations
  134. void freeze_storage() const;
  135. // Every FunctionalTensorWrapper contains a vector<ViewMeta> objects
  136. // describing the series of view ops that ran to generate the current tensor
  137. // from the base tensor. This method is used by inplace-view ops like
  138. // transpose_. It appends a ViewMeta to the existing stack, and refreshes the
  139. // tensor by replaying the views off of the alias.
  140. void mutate_view_meta(
  141. const std::shared_ptr<at::functionalization::ViewMeta>& meta);
  142. // Custom implementation of self.set_(src)
  143. void set__impl(const FunctionalTensorWrapper* other);
  144. // Custom implementation of resize_storage_bytes_(self, new_size)
  145. void storage_resize_(const c10::SymInt& new_size);
  146. // Returns whether the current tensor's data was ever mutated
  147. bool has_data_mutation();
  148. //
  149. // Returns whether the current FunctionalTensorWrapper
  150. // experienced a set_() call.
  151. bool was_storage_changed() {
  152. return was_storage_changed_;
  153. }
  154. void mark_storage_changed() {
  155. was_storage_changed_ = true;
  156. storage_changed_counter_++;
  157. }
  158. uint64_t storage_changed_counter() {
  159. return storage_changed_counter_;
  160. }
  161. // A FunctionalTensor is considered a base if its not a view of another
  162. // tensor.
  163. bool isBaseTensor() const {
  164. return view_metas_.empty();
  165. }
  166. c10::SymInt get_storage_size(bool before) {
  167. return functional_storage_impl()->get_storage_size(before);
  168. }
  169. // Returns whether the FunctionalTensor experienced an
  170. // untyped_storage().resize_() call
  171. bool was_inductor_storage_resized() {
  172. return functional_storage_impl()->was_inductor_storage_resized();
  173. }
  174. bool inductor_storage_resized_counter() {
  175. return functional_storage_impl()->inductor_storage_resized_counter();
  176. }
  177. // The functionalization pass can be used to remove mutations.
  178. // It does so by replacing any mutation op with it's corresponding
  179. // out-of-place op, followed by a call to replace_(). e.g:
  180. //
  181. // a.add_(1)
  182. //
  183. // will turn into:
  184. //
  185. // tmp = a.add(1)
  186. // a.replace_(tmp)
  187. //
  188. // replace_() swaps out the wrapped tensor, value_, with tmp.
  189. void replace_(const Tensor& other, bool from_lazy_regenerate = false);
  190. bool is_multi_output_view() {
  191. return is_multi_output_view_;
  192. }
  193. // See Note[resize_() in functionalization pass]
  194. void maybe_replace_storage(const Tensor& other);
  195. // Replaces the storage with a new functional storage,
  196. // and clears the view_metas_ stack.
  197. // WARNING: Calling this function will sever the aliasing relationship between
  198. // the current FunctionalTensorWrapper and any of its outstanding aliases.
  199. // Please only call if you know what you're doing.
  200. void _unsafe_reset_storage();
  201. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  202. const c10::VariableVersion& version_counter,
  203. bool allow_tensor_metadata_change) const override;
  204. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
  205. c10::VariableVersion&& version_counter,
  206. bool allow_tensor_metadata_change) const override;
  207. ~FunctionalTensorWrapper() override = default;
  208. // FunctionalTensorWrapper overrides all custom size/stride function,
  209. // so that if the inner tensor has a custom implementation
  210. // we make sure to call that implementation.
  211. at::IntArrayRef sizes_custom() const override;
  212. at::IntArrayRef strides_custom() const override;
  213. int64_t dim_custom() const override;
  214. int64_t numel_custom() const override;
  215. c10::SymBool sym_is_contiguous_custom(
  216. at::MemoryFormat memory_format) const override;
  217. c10::SymIntArrayRef sym_sizes_custom() const override;
  218. c10::SymInt sym_size_custom(int64_t d) const override;
  219. c10::SymIntArrayRef sym_strides_custom() const override;
  220. c10::SymInt sym_storage_offset_custom() const override;
  221. c10::Device device_custom() const override;
  222. c10::Layout layout_impl() const override;
  223. private:
  224. const char* tensorimpl_type_name() const override;
  225. void set_constructor_metadata();
  226. functionalization::FunctionalStorageImpl* functional_storage_impl() const;
  227. // This is used to re-implement shallow_copy_and_detach for
  228. // FunctionalTensorWrapper. The implementation is identical, but we just need
  229. // to return a subclass instead of a plain TensorImpl.
  230. // TODO: maybe it's possible to arrange for that to happen automatically
  231. // without an override here?
  232. template <typename VariableVersion>
  233. c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
  234. VariableVersion&& version_counter,
  235. bool allow_tensor_metadata_change) const;
  236. void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
  237. void copy_tensor_metadata_and_refresh(
  238. const FunctionalTensorWrapper* src_impl,
  239. FunctionalTensorWrapper* dest_impl,
  240. const c10::VariableVersion& version_counter,
  241. bool allow_tensor_metadata_change) const;
  242. // Note that value is not taken by reference: internally, the wrapper will
  243. // change the value tensor that it points to over time.
  244. Tensor value_;
  245. int64_t level_{};
  246. // These two counters are used for identifying
  247. // whether all the mutations on a given tensor are hidden from autograd or
  248. // not. If we have an input mutation that is hidden from autograd, then once
  249. // we convert the input mutation to a copy_() we know it will be safe to hide
  250. // the copy_() from autograd as well.
  251. bool has_metadata_mutation_ = false;
  252. bool is_multi_output_view_ = false;
  253. // Did the tensor experience a set_() call.
  254. bool was_storage_changed_ = false;
  255. uint64_t storage_changed_counter_ = 0;
  256. // Did the tensor experience any view operation with symbolic int.
  257. bool is_symbolic_ = false;
  258. size_t generation_ = 0;
  259. std::vector<std::shared_ptr<at::functionalization::ViewMeta>> view_metas_;
  260. protected:
  261. static void copy_tensor_metadata(
  262. const FunctionalTensorWrapper* src_impl,
  263. FunctionalTensorWrapper* dest_impl,
  264. const c10::VariableVersion& version_counter,
  265. bool allow_tensor_metadata_change);
  266. };
  267. // Utility functions for the functionalization pass.
  268. namespace functionalization {
  269. namespace impl {
  270. inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper(
  271. const Tensor& tensor) {
  272. auto functional_impl =
  273. static_cast<FunctionalTensorWrapper*>(tensor.unsafeGetTensorImpl());
  274. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr);
  275. return functional_impl;
  276. }
  // Boolean queries over tensors (definitions live outside this header).
  // NOTE(review): isBaseTensor(tensor) is presumably the free-function
  // counterpart of FunctionalTensorWrapper::isBaseTensor() - confirm in the
  // implementation file.
  277. TORCH_API bool isBaseTensor(const at::Tensor& tensor);
  // isFunctionalTensor is overloaded for a single tensor, an optional tensor,
  // a list of optional tensors, and an ITensorListRef.
  // NOTE(review): how the list overloads combine per-element answers (and
  // what an empty optional yields) is not visible here - confirm before
  // relying on it.
  278. TORCH_API bool isFunctionalTensor(const at::Tensor& tensor);
  279. TORCH_API bool isFunctionalTensor(const std::optional<Tensor>& t);
  280. TORCH_API bool isFunctionalTensor(
  281. const c10::List<std::optional<Tensor>>& t_list);
  282. TORCH_API bool isFunctionalTensor(ITensorListRef list);
  // Conversions into functional tensors. One overload per argument shape:
  // single tensor, optional tensor, list of optional tensors, and
  // ITensorListRef (which returns a std::vector<Tensor>).
  283. TORCH_API Tensor to_functional_tensor(const Tensor& tensor);
  284. TORCH_API std::optional<Tensor> to_functional_tensor(
  285. const std::optional<Tensor>& tensor);
  286. TORCH_API c10::List<std::optional<Tensor>> to_functional_tensor(
  287. const c10::List<std::optional<Tensor>>& t_list);
  288. TORCH_API std::vector<Tensor> to_functional_tensor(ITensorListRef t_list);
  // NOTE(review): presumably forwards to
  // FunctionalTensorWrapper::freeze_storage() - confirm in the implementation.
  289. TORCH_API void freeze_functional_tensor(const Tensor& tensor);
  // Conversions out of functional tensors, mirroring the overload set above.
  // Only the single-tensor and optional-tensor overloads accept
  // `assert_functional` (default: true); the two list overloads take no such
  // flag.
  290. TORCH_API Tensor
  291. from_functional_tensor(const Tensor& tensor, bool assert_functional = true);
  292. TORCH_API std::optional<Tensor> from_functional_tensor(
  293. const std::optional<Tensor>& t,
  294. bool assert_functional = true);
  295. TORCH_API c10::List<std::optional<Tensor>> from_functional_tensor(
  296. const c10::List<std::optional<Tensor>>& t_list);
  297. TORCH_API std::vector<Tensor> from_functional_tensor(ITensorListRef t_list);
  // Free-function entry points that take at::Tensor arguments rather than a
  // FunctionalTensorWrapper*. Their names mirror wrapper methods (sync_,
  // replace_, commit_update, _unsafe_reset_storage, mark_mutation_hidden_from_
  // autograd, are_all_mutations_*).
  // NOTE(review): presumably each unwraps the tensor and forwards to the
  // same-named wrapper method - confirm in the implementation file.
  //
  // sync: overloads for a tensor, an optional tensor, a list of optional
  // tensors, and an ITensorListRef.
  298. TORCH_API void sync(const at::Tensor& t);
  299. TORCH_API void sync(const std::optional<Tensor>& t);
  300. TORCH_API void sync(const c10::List<std::optional<Tensor>>& t_list);
  301. TORCH_API void sync(ITensorListRef t_list);
  // replace_ / commit_update: single-tensor and list overloads.
  302. TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other);
  303. TORCH_API void replace_(
  304. const ITensorListRef functional_tensor,
  305. ITensorListRef other);
  306. TORCH_API void commit_update(const Tensor& functional_tensor);
  307. TORCH_API void commit_update(ITensorListRef functional_tensor);
  308. TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor);
  // Mutation-visibility helpers; single-tensor only.
  309. TORCH_API void mark_mutation_hidden_from_autograd(
  310. const Tensor& functional_tensor);
  311. TORCH_API bool are_all_mutations_hidden_from_autograd(
  312. const Tensor& functional_tensor);
  313. TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode(
  314. const Tensor& functional_tensor);
  315. // These two methods are XLA-specific logic and are no-ops
  316. // for the normal functionalization flow.
  // ("Two methods" refers to the two names below; each is declared with a
  // single-tensor overload and an ITensorListRef overload.)
  // NOTE(review): the first parameter of propagate_xla_data is named
  // `functional_tensor` while propagate_xla_data_direct names it `tensor`,
  // which suggests the latter takes tensors that are not functional wrappers -
  // confirm in the implementation.
  317. TORCH_API void propagate_xla_data(
  318. const Tensor& functional_tensor,
  319. const Tensor& other);
  320. TORCH_API void propagate_xla_data(
  321. const ITensorListRef functional_tensor,
  322. ITensorListRef other);
  323. TORCH_API void propagate_xla_data_direct(
  324. const Tensor& tensor,
  325. const Tensor& other);
  326. TORCH_API void propagate_xla_data_direct(
  327. const ITensorListRef tensor,
  328. ITensorListRef other);
  // View-meta helpers. Unlike most neighbouring declarations, only
  // apply_view_meta_sequence in this group carries TORCH_API.
  //
  // create_functional_tensor_with_view_meta: single-output overload takes an
  // `out_idx` (default 0); the list overload takes no index and returns one
  // tensor per element as a std::vector<Tensor>.
  // NOTE(review): presumably these build wrappers through the
  // (view_value, base, meta) FunctionalTensorWrapper constructor - confirm.
  329. Tensor create_functional_tensor_with_view_meta(
  330. const Tensor& view_to_wrap,
  331. const Tensor& base,
  332. const std::shared_ptr<functionalization::ViewMeta>& meta,
  333. int64_t out_idx = 0);
  334. std::vector<Tensor> create_functional_tensor_with_view_meta(
  335. ITensorListRef view_to_wrap,
  336. const Tensor& base,
  337. const std::shared_ptr<functionalization::ViewMeta>& meta);
  // Tensor-taking counterpart of FunctionalTensorWrapper::mutate_view_meta
  // (same name and `meta` parameter type).
  338. void mutate_view_meta(
  339. const Tensor& self,
  340. const std::shared_ptr<functionalization::ViewMeta>& meta);
  // Takes a base tensor and a ViewMeta sequence and returns a Tensor.
  // NOTE(review): presumably applies each ViewMeta in order starting from
  // `base` - confirm.
  341. TORCH_API Tensor apply_view_meta_sequence(
  342. const Tensor& base,
  343. const std::vector<std::shared_ptr<functionalization::ViewMeta>>& sequence);
  // NOTE(review): by name, copies sizes/strides/storage offset from
  // `meta_out(s)` onto `out(s)`; the vector overload presumably pairs elements
  // by index - confirm, including what happens on a length mismatch.
  344. void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out);
  345. void set_sizes_strides_offset(
  346. const std::vector<Tensor>& outs,
  347. const std::vector<Tensor>& meta_outs);
  348. // ~~~~~ TLS used in functionalization ~~~~~
  // Getter/setter for the "reapply views" flag; these are the two functions
  // FunctionalizationReapplyViewsGuard calls to save and restore the value.
  349. TORCH_API bool getFunctionalizationReapplyViewsTLS();
  350. TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views);
  351. class TORCH_API FunctionalizationReapplyViewsGuard {
  352. public:
  353. FunctionalizationReapplyViewsGuard(bool reapply_views)
  354. : prev_(getFunctionalizationReapplyViewsTLS()) {
  355. setFunctionalizationReapplyViewsTLS(reapply_views);
  356. }
  357. ~FunctionalizationReapplyViewsGuard() {
  358. setFunctionalizationReapplyViewsTLS(prev_);
  359. }
  360. FunctionalizationReapplyViewsGuard(
  361. const FunctionalizationReapplyViewsGuard&) = delete;
  362. FunctionalizationReapplyViewsGuard operator=(
  363. const FunctionalizationReapplyViewsGuard&) = delete;
  364. FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) =
  365. delete;
  366. FunctionalizationReapplyViewsGuard operator=(
  367. FunctionalizationReapplyViewsGuard&&) = delete;
  368. private:
  369. bool prev_;
  370. };
  371. } // namespace impl
  372. // Helper function to call an out-of-place composite aten kernel that may use
  373. // mutations / views internally, and functionalize them.
  // It takes an operator handle plus a JIT stack and returns nothing; the
  // _functionalize_aten_op template in this header wraps it with
  // c10::BoxedKernel::makeFromFunction so it can be invoked with unboxed
  // arguments.
  374. TORCH_API void functionalize_op_helper(
  375. const c10::OperatorHandle& op,
  376. torch::jit::Stack* stack);
  377. template <class Op, bool symint, class ReturnType, class... ParameterTypes>
  378. struct _functionalize_aten_op final {};
  379. template <class Op, bool symint, class ReturnType, class... ParameterTypes>
  380. struct _functionalize_aten_op<Op, symint, ReturnType(ParameterTypes...)> final {
  381. static ReturnType call(
  382. typename c10::maybe_keep_symint<symint, ParameterTypes>::type... args) {
  383. using FuncType = ReturnType(
  384. typename c10::maybe_keep_symint<symint, ParameterTypes>::type...);
  385. auto op = c10::Dispatcher::singleton()
  386. .findSchemaOrThrow(
  387. (const char*)Op::name, (const char*)Op::overload_name)
  388. .typed<FuncType>();
  389. return c10::impl::BoxedKernelWrapper<FuncType>::call(
  390. c10::BoxedKernel::makeFromFunction<functionalize_op_helper>(),
  391. op,
  392. // BoxedKernelWrapper knows to ignore this keyset argument,
  393. // because functionalize_op_helper doesn't take in a DispatchKeySet
  394. c10::DispatchKeySet(),
  395. args...);
  396. }
  397. };
  398. template <class Op>
  399. using functionalize_aten_op =
  400. _functionalize_aten_op<Op, false, typename Op::schema>;
  401. template <class Op>
  402. using functionalize_aten_op_symint =
  403. _functionalize_aten_op<Op, true, typename Op::schema>;
  404. } // namespace functionalization
  405. } // namespace at
  406. #else
  407. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  408. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)