PythonTorchFunctionTLS.h 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/SafePyObject.h>
  4. #include <c10/macros/Macros.h>
  5. namespace at::impl {
  // Describes which parts of torch function handling are disabled.
  // Kept as an unscoped enum so the enumerators remain visible directly in the
  // enclosing namespace; the values are the same 0, 1, 2 the implicit numbering
  // would assign.
  enum TorchFunctionDisabledState {
    // Nothing is disabled (per the name).
    ENABLED = 0,
    // NOTE(review): presumably only the subclass path is disabled while modes
    // stay active -- confirm against the users of this enum.
    SUBCLASSES_DISABLED = 1,
    // NOTE(review): presumably every torch function path is disabled --
    // confirm against the users of this enum.
    ALL_DISABLED = 2,
  };
  7. struct TORCH_API PythonTorchFunctionTLS {
  8. static void set_disabled_state(TorchFunctionDisabledState disabled_state_);
  9. static TorchFunctionDisabledState get_disabled_state();
  10. static void push_onto_stack(std::shared_ptr<SafePyObject> mode);
  11. static const std::shared_ptr<SafePyObject> pop_stack();
  12. static const std::shared_ptr<SafePyObject>& get_stack_at(int64_t idx);
  13. static int64_t stack_len();
  14. static const PythonTorchFunctionTLS& get_state();
  15. static void set_state(const PythonTorchFunctionTLS& state);
  16. private:
  17. // The mode TLS is split into
  18. // - disabled_state, which says which part of torch function are disabled
  19. // - stack_, which is a vector of modes representing the stack of user
  20. // defined modes
  21. TorchFunctionDisabledState disabled_state_ =
  22. TorchFunctionDisabledState::ENABLED;
  23. std::vector<std::shared_ptr<c10::SafePyObject>> stack_;
  24. friend TORCH_API bool torch_function_mode_enabled();
  25. };
  26. TORCH_API bool torch_function_mode_enabled();
  27. TORCH_API bool torch_function_all_disabled();
  28. } // namespace at::impl
  29. #else
  30. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  31. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)