SavedTensorHooks.h 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/SafePyObject.h>
  4. #include <c10/macros/Export.h>
  5. #include <c10/util/python_stub.h>
  6. #include <optional>
  7. #include <stack>
  8. #include <string>
  9. #include <utility>
  10. namespace at {
  11. namespace impl {
  12. struct TORCH_API SavedTensorDefaultHooksTLS {
  13. // PyObject is defined in c10/util/python_stub.h
  14. std::stack<std::pair<c10::SafePyObject, c10::SafePyObject>> stack;
  15. // See NOTE: [Disabling SavedTensorDefaultHooks] for context
  16. // NOTE: [disabled_error_message invariant]
  17. // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
  18. // We did this for efficiency (so we didn't have to keep a separate bool
  19. // around)
  20. std::optional<std::string> disabled_error_message;
  21. // See NOTE: [Deferring tensor pack/unpack hooks until runtime]
  22. bool is_tracing = false;
  23. };
  24. } // namespace impl
  25. struct TORCH_API SavedTensorDefaultHooks {
  26. static void push_hooks(
  27. c10::SafePyObject pack_hook,
  28. c10::SafePyObject unpack_hook);
  29. static std::pair<c10::SafePyObject, c10::SafePyObject> pop_hooks();
  30. static std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  31. get_hooks(bool ignore_is_tracing = false);
  32. static void lazy_initialize();
  33. static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
  34. static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
  35. // NOTE: [Disabling SavedTensorDefaultHooks]
  36. // A developer of a PyTorch feature may choose to disable SavedTensorDefault
  37. // hooks, especially if their feature does not work with it. If they are
  38. // disabled, then the following will raise an error:
  39. // - Attempting to push_hooks
  40. // - calling disable(message) with a non-zero stack (hooks) size
  41. static void disable(
  42. const std::string& error_message,
  43. const bool fail_if_non_empty = true);
  44. static void enable();
  45. static bool is_enabled();
  46. static const std::optional<std::string>& get_disabled_error_message();
  47. // NOTE: [Deferring tensor pack/unpack hooks until runtime]
  48. // To preserve eager semantics of pack/unpack hooks firing only once per saved
  49. // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using
  50. // disable() would loud error at trace time, and pushing a no-op hook would
  51. // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx.
  52. // To do so, we disable these hooks during tracing. See
  53. // https://github.com/pytorch/pytorch/issues/113263.
  54. static bool set_tracing(bool is_tracing);
  55. };
  56. } // namespace at
  57. #else
  58. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  59. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)