| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <c10/core/SafePyObject.h>
- #include <c10/macros/Export.h>
- #include <c10/util/python_stub.h>
- #include <optional>
- #include <stack>
- #include <string>
- #include <utility>
- namespace at {
- namespace impl {
- struct TORCH_API SavedTensorDefaultHooksTLS {
-   // Stack of registered hook pairs. Each entry is presumably
-   // (pack_hook, unpack_hook), matching the argument order of
-   // SavedTensorDefaultHooks::push_hooks -- TODO confirm in the implementation.
-   // (PyObject itself is defined in c10/util/python_stub.h.)
-   std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
-   // See NOTE: [Disabling SavedTensorDefaultHooks] for context
-   // NOTE: [disabled_error_message invariant]
-   // disabled_error_message is nullopt if and only if saved tensor hooks are
-   // enabled. Encoding the enabled/disabled state in the optional is an
-   // efficiency choice: no separate bool has to be kept around.
-   std::optional<std::string> disabled_error_message;
-   // See NOTE: [Deferring tensor pack/unpack hooks until runtime]
-   bool is_tracing{false};
- };
- } // namespace impl
- // Static interface for the default saved-tensor pack/unpack hooks. The
- // underlying state is exposed as impl::SavedTensorDefaultHooksTLS through
- // get_tls_state() / set_tls_state().
- struct TORCH_API SavedTensorDefaultHooks {
-   // Registers a (pack_hook, unpack_hook) pair. Raises an error while hooks
-   // are disabled (see NOTE: [Disabling SavedTensorDefaultHooks]).
-   static void push_hooks(
-       c10::SafePyObject pack_hook,
-       c10::SafePyObject unpack_hook);
-   // Returns a hook pair; presumably removes and returns the most recently
-   // pushed one -- TODO confirm in the implementation.
-   static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
-   // Returns an optional hook pair. Presumably nullopt when no hooks are
-   // active, and ignore_is_tracing presumably bypasses the is_tracing flag
-   // (see NOTE: [Deferring tensor pack/unpack hooks until runtime]) --
-   // TODO confirm in the implementation.
-   static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
-   get_hooks(bool ignore_is_tracing = false);
-   // NOTE(review): behavior is not visible from this header; see the
-   // implementation.
-   static void lazy_initialize();
-   // Read access to the current state, returned by const reference.
-   static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
-   // Presumably replaces the current state with `tls` -- TODO confirm.
-   static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
-   // NOTE: [Disabling SavedTensorDefaultHooks]
-   // A developer of a PyTorch feature may choose to disable SavedTensorDefault
-   // hooks, especially if their feature does not work with it. If they are
-   // disabled, then the following will raise an error:
-   // - Attempting to push_hooks
-   // - calling disable(message) with a non-zero stack (hooks) size
-   //   (presumably only when fail_if_non_empty is true -- TODO confirm)
-   static void disable(
-       const std::string& error_message,
-       bool fail_if_non_empty = true);
-   // Counterpart of disable().
-   static void enable();
-   // Whether hooks are currently enabled. Per
-   // NOTE: [disabled_error_message invariant], the stored message is nullopt
-   // if and only if hooks are enabled.
-   static bool is_enabled();
-   // The message supplied when hooks were disabled; nullopt while enabled
-   // (see NOTE: [disabled_error_message invariant]).
-   static const std::optional<std::string>& get_disabled_error_message();
-   // NOTE: [Deferring tensor pack/unpack hooks until runtime]
-   // To preserve eager semantics of pack/unpack hooks firing only once per saved
-   // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
-   // disable() would raise a loud error at trace time, and pushing a no-op hook
-   // would fail when the traced code is wrapped in a
-   // disable_saved_tensors_hooks ctx.
-   // To do so, we disable these hooks during tracing. See
-   // https://github.com/pytorch/pytorch/issues/113263.
-   // The returned bool is presumably the previous is_tracing value --
-   // TODO confirm in the implementation.
-   static bool set_tracing(bool is_tracing);
- };
- } // namespace at
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|