ThreadLocalDebugInfo.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/macros/Export.h>
  4. #include <cstdint>
  5. #include <memory>
  6. namespace c10 {
  7. enum class C10_API_ENUM DebugInfoKind : uint8_t {
  8. PRODUCER_INFO = 0,
  9. MOBILE_RUNTIME_INFO,
  10. PROFILER_STATE,
  11. INFERENCE_CONTEXT, // for inference usage
  12. PARAM_COMMS_INFO,
  13. TEST_INFO, // used only in tests
  14. TEST_INFO_2, // used only in tests
  15. };
  16. class C10_API DebugInfoBase {
  17. public:
  18. DebugInfoBase() = default;
  19. virtual ~DebugInfoBase() = default;
  20. };
  21. // Thread local debug information is propagated across the forward
  22. // (including async fork tasks) and backward passes and is supposed
  23. // to be utilized by the user's code to pass extra information from
  24. // the higher layers (e.g. model id) down to the lower levels
  25. // (e.g. to the operator observers used for debugging, logging,
  26. // profiling, etc)
  27. class C10_API ThreadLocalDebugInfo {
  28. public:
  29. static DebugInfoBase* get(DebugInfoKind kind);
  30. // Get current ThreadLocalDebugInfo
  31. static std::shared_ptr<ThreadLocalDebugInfo> current();
  32. // Internal, use DebugInfoGuard/ThreadLocalStateGuard
  33. static void _forceCurrentDebugInfo(
  34. std::shared_ptr<ThreadLocalDebugInfo> info);
  35. // Push debug info struct of a given kind
  36. static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
  37. // Pop debug info, throws in case the last pushed
  38. // debug info is not of a given kind
  39. static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
  40. // Peek debug info, throws in case the last pushed debug info is not of the
  41. // given kind
  42. static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
  43. private:
  44. std::shared_ptr<DebugInfoBase> info_;
  45. DebugInfoKind kind_;
  46. std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
  47. friend class DebugInfoGuard;
  48. };
  49. // DebugInfoGuard is used to set debug information,
  50. // ThreadLocalDebugInfo is semantically immutable, the values are set
  51. // through the scope-based guard object.
  52. // Nested DebugInfoGuard adds/overrides existing values in the scope,
  53. // restoring the original values after exiting the scope.
  54. // Users can access the values through the ThreadLocalDebugInfo::get() call;
  55. class C10_API DebugInfoGuard {
  56. public:
  57. DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
  58. explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
  59. ~DebugInfoGuard();
  60. DebugInfoGuard(const DebugInfoGuard&) = delete;
  61. DebugInfoGuard(DebugInfoGuard&&) = delete;
  62. DebugInfoGuard& operator=(const DebugInfoGuard&) = delete;
  63. DebugInfoGuard& operator=(DebugInfoGuard&&) = delete;
  64. private:
  65. bool active_ = false;
  66. std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
  67. };
  68. } // namespace c10
  69. #else
  70. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  71. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)