TracerMode.h 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/impl/LocalDispatchKeySet.h>
  4. #include <c10/macros/Export.h>
  5. #include <c10/macros/Macros.h>
  6. // NOTE [Tracing Mode Switches]
  7. //
  8. // Historically, tracing was controlled by two switches:
  9. //
  10. // - `AutoDispatchBelowADInplaceOrView` guard
  11. //
  12. // Tracing logic used to be script-generated inside `VariableType_*.cpp`
  13. // kernels, sharing the same `Autograd` dispatch key with autograd logic.
  14. // Therefore, before tracing was moved out of VariableType, the
  15. // `AutoDispatchBelowADInplaceOrView` guard could also disable tracing as a
  16. // side effect of disabling `Autograd` dispatching.
  17. //
  18. // - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h`
  19. //
  20. // It stores tracing data in a `TracingState` object in TLS. If the
  21. // `TracingState` object in TLS is `null`, then tracing is paused.
  22. //
  23. // The `TracingState` object is created in `tracer::trace()` - the main
  24. // entrance of tracing function. It's temporarily set to `null` inside
  25. // generated VariableType (now TraceType) to bypass tracing for intermediate
  26. // ops (ops being called by other ops). After the intermediate op call
  27. // finishes, it's set back to the original `TracingState` object.
  28. //
  29. // The `TracingState` object in TLS can also be read/written via its Python
  30. // binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs,
  31. // which are also exposed as `TORCH_API`.
  32. //
  33. // Two new switches were introduced when tracing was moved out of
  34. // VariableType:
  35. //
  36. // - `tracer::impl::set_dispatch_enabled()` API
  37. //
  38. // Unlike the special `Autograd` dispatch key, which is included in the
  39. // dispatch key set by default, the `Tracer` dispatch key is off by default.
  40. // The dispatching switch can be toggled via this new API.
  41. //
  42. // - `tracer::impl::NoTracerDispatchMode` guard
  43. //
  44. // It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView`
  45. // after tracing was moved out of VariableType.
  46. //
  47. // Before tracing was moved out of VariableType, tracing was enabled
  48. // when the following conditions are satisfied:
  49. //
  50. // 1) `TracingState` object in TLS != null;
  51. // - Either inside the execution scope of `tracer::trace()`, or
  52. // - Eagerly called `setTracingState()` with non-null object.
  53. // 2) Not inside `AutoDispatchBelowADInplaceOrView` scope;
  54. //
  55. // After:
  56. //
  57. // 1) `TracingState` object in TLS != null;
  58. // 2) Has called `tracer::impl::set_dispatch_enabled(true)`;
  59. // 3) Not inside `tracer::impl::NoTracerDispatchMode` scope;
  60. //
  61. // [TODOs]
  62. //
  63. // - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()`
  64. //
  65. // Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()`
  66. // to keep the semantics exactly the same as before - it's confusing to keep
  67. // both switches, though. We should consider simplifying/limiting the exposed
  68. // `setTracingState()` Python/C++ APIs (and other APIs calling it) so that
  69. // these two can be unified.
  70. //
  71. // - `AutoDispatchBelowADInplaceOrView` v.s.
  72. // `tracer::impl::NoTracerDispatchMode`
  73. //
  74. // We don't need to always set both guards together to keep semantics
  75. // unchanged. For the following use cases of `AutoDispatchBelowADInplaceOrView`
  76. // we don't need to set the new tracer guard:
  77. //
  78. // * Script-generated VariableType kernels. The guard is not necessary as
  79. // tracing is already disabled explicitly by `setTracingState(null)` in
  80. // generated TraceType kernels - we could keep it as is or use the new guard
  81. // instead.
  82. //
  83. // * Custom ops. Will be handled by fallback kernel for `Tracer`.
  84. //
  85. // * Functions that are not likely to be called in tracing context (no python
  86. // binding / not an operator), e.g.: all mobile forward() wrappers, test
  87. // binaries, etc.
  88. //
  89. // * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp.
  90. // It's not necessary as tracing is off by default.
  91. //
  92. // For the remaining cases we might need to have both:
  93. //
  94. // * Functions that might be reachable from eager mode python (especially
  95. // factory methods), e.g.:
  96. // `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`.
  97. // Without the new guard it will add `aten::empty` to the traced graph.
  98. //
  99. // * Some manually maintained functions, e.g.:
  100. // `torch/csrc/autograd/VariableTypeManual.cpp`.
  101. // Set the new guard if it's not obvious whether `setTracingState(null)`
  102. // has been called before it reaches the `AutoDispatchBelowADInplaceOrView`
  103. // guard.
  104. //
  105. // We might need to tweak the usage of the new guard to optimize/fix things.
  106. // It should only affect the correctness of tracing, because the
  107. // guard is essentially a no-op when the master `setTracingState()` switch is
  108. // off.
  109. // TODO: move this from `at::` to `torch::jit::` after
  110. // `aten/src/ATen/cpp_custom_type_hack.h` is removed.
  111. namespace at::tracer::impl {
  112. inline bool is_dispatch_enabled() {
  113. return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) &&
  114. !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer);
  115. }
  116. inline void set_dispatch_enabled(bool enabled) {
  117. TORCH_INTERNAL_ASSERT(
  118. !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer),
  119. "Cannot enable tracing within the scope of NoTracerDispatchMode!");
  120. c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled);
  121. }
  122. struct NoTracerDispatchMode {
  123. c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer};
  124. };
  125. } // namespace at::tracer::impl
  126. #else
  127. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  128. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)