ThreadLocalState.h 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/InferenceMode.h>
  4. #include <c10/core/impl/LocalDispatchKeySet.h>
  5. #include <c10/util/Exception.h>
  6. #include <c10/util/ThreadLocalDebugInfo.h>
  7. #include <ATen/FuncTorchTLS.h>
  8. #include <ATen/PythonTorchFunctionTLS.h>
  9. #include <ATen/SavedTensorHooks.h>
  10. #include <ATen/ThreadLocalPythonObjects.h>
  11. #include <ATen/record_function.h>
  12. #include <c10/core/impl/PythonDispatcherTLS.h>
  13. #include <c10/core/impl/TorchDispatchModeTLS.h>
  14. namespace at {
  15. // Thread local state contains values that are preserved across
  16. // thread boundaries (e.g. at::launch/JIT fork, autograd).
  17. // Note at::parallel_for doesn't preserve TLS across thread boundaries.
  18. class TORCH_API ThreadLocalState {
  19. public:
  20. // Saves the thread local variables' values and
  21. // returns them as a ThreadLocalState
  22. ThreadLocalState();
  23. // set_grad_mode - force the value of the grad mode TLS in
  24. // the current state object. This is used for example in the
  25. // autograd engine.
  26. void set_grad_mode(bool enabled);
  27. // set_multithreading_enabled - force the value of the multithreadinmaximum
  28. // threads TLS in
  29. // the current state object. This is used for example in the
  30. // autograd engine.
  31. void set_multithreading_enabled(bool enabled);
  32. // Sets thread local variables in the current thread,
  33. // according to the thread boundary specified
  34. static void setThreadLocalState(const ThreadLocalState& state);
  35. private:
  36. c10::impl::LocalDispatchKeySet dispatch_key_;
  37. // ThreadLocalDebugInfo does not change after being created
  38. // with DebugInfoGuard
  39. std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
  40. // RecordFunction TLS
  41. RecordFunctionTLS rf_tls_;
  42. // TLS for out-of-tree functorch
  43. // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
  44. // pointer (spoiler alert: it's due to the indirection)
  45. // This needs to be a shared_ptr instead of a unique_ptr because
  46. // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
  47. // consider adding an explicit copy constructor for ThreadLocalState in the
  48. // future but I didn't want to add one just for this.
  49. std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
  50. // TLS for AutogradModes
  51. AutogradState autograd_tls_;
  52. // TLS for enable_torch_dispatch_mode
  53. c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
  54. // TLS for enable_python_dispatcher
  55. c10::impl::PyInterpreter* python_dispatcher_state_;
  56. // TLS for __torch_function__ (mode and disable_torch_function)
  57. at::impl::PythonTorchFunctionTLS python_torch_function_state_;
  58. // TLS for saved tensors default hooks
  59. at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
  60. bool functionalization_reapply_views_state_;
  61. bool dtensor_allow_implicit_replication_;
  62. // TLS for arbitrary python objects that is registered via hooks
  63. at::impl::ThreadLocalPythonObjects saved_objects_;
  64. #if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \
  65. !defined(BUILD_LITE_INTERPRETER)
  66. // TLS for autocast dtypes
  67. std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
  68. autocast_dtypes_{};
  69. #endif
  70. friend class ThreadLocalStateGuard;
  71. };
  72. // Guard to set and reset the thread local state
  73. class TORCH_API ThreadLocalStateGuard {
  74. public:
  75. explicit ThreadLocalStateGuard(const ThreadLocalState& state)
  76. : prev_state_(ThreadLocalState()) {
  77. // set the given state across the thread boundary
  78. ThreadLocalState::setThreadLocalState(state);
  79. }
  80. ThreadLocalStateGuard(ThreadLocalStateGuard&& other) = delete;
  81. ThreadLocalStateGuard(const ThreadLocalStateGuard&) = delete;
  82. ThreadLocalStateGuard& operator=(const ThreadLocalStateGuard&) = delete;
  83. ThreadLocalStateGuard& operator=(ThreadLocalStateGuard&&) = delete;
  84. ~ThreadLocalStateGuard() {
  85. // restore previously set variables
  86. ThreadLocalState::setThreadLocalState(prev_state_);
  87. }
  88. private:
  89. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  90. const ThreadLocalState prev_state_;
  91. };
  92. template <typename T>
  93. auto wrapPropagateTLSState(T callback) {
  94. return [tls_state = ThreadLocalState(),
  95. callback = std::move(callback)](auto&&... args) {
  96. ThreadLocalStateGuard g(tls_state);
  97. // Propagate value returned by callback().
  98. return callback(std::forward<decltype(args)>(args)...);
  99. };
  100. }
  101. } // namespace at
  102. #else
  103. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  104. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)