| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <c10/macros/Export.h>
- #include <cstdint>
- #include <memory>
- namespace c10 {
- enum class C10_API_ENUM DebugInfoKind : uint8_t {
- PRODUCER_INFO = 0,
- MOBILE_RUNTIME_INFO,
- PROFILER_STATE,
- INFERENCE_CONTEXT, // for inference usage
- PARAM_COMMS_INFO,
- TEST_INFO, // used only in tests
- TEST_INFO_2, // used only in tests
- };
- class C10_API DebugInfoBase {
- public:
- DebugInfoBase() = default;
- virtual ~DebugInfoBase() = default;
- };
- // Thread local debug information is propagated across the forward
- // (including async fork tasks) and backward passes and is supposed
- // to be utilized by the user's code to pass extra information from
- // the higher layers (e.g. model id) down to the lower levels
- // (e.g. to the operator observers used for debugging, logging,
- // profiling, etc)
- class C10_API ThreadLocalDebugInfo {
- public:
- static DebugInfoBase* get(DebugInfoKind kind);
- // Get current ThreadLocalDebugInfo
- static std::shared_ptr<ThreadLocalDebugInfo> current();
- // Internal, use DebugInfoGuard/ThreadLocalStateGuard
- static void _forceCurrentDebugInfo(
- std::shared_ptr<ThreadLocalDebugInfo> info);
- // Push debug info struct of a given kind
- static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
- // Pop debug info, throws in case the last pushed
- // debug info is not of a given kind
- static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
- // Peek debug info, throws in case the last pushed debug info is not of the
- // given kind
- static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
- private:
- std::shared_ptr<DebugInfoBase> info_;
- DebugInfoKind kind_;
- std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
- friend class DebugInfoGuard;
- };
- // DebugInfoGuard is used to set debug information,
- // ThreadLocalDebugInfo is semantically immutable, the values are set
- // through the scope-based guard object.
- // Nested DebugInfoGuard adds/overrides existing values in the scope,
- // restoring the original values after exiting the scope.
- // Users can access the values through the ThreadLocalDebugInfo::get() call;
- class C10_API DebugInfoGuard {
- public:
- DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
- explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
- ~DebugInfoGuard();
- DebugInfoGuard(const DebugInfoGuard&) = delete;
- DebugInfoGuard(DebugInfoGuard&&) = delete;
- DebugInfoGuard& operator=(const DebugInfoGuard&) = delete;
- DebugInfoGuard& operator=(DebugInfoGuard&&) = delete;
- private:
- bool active_ = false;
- std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
- };
- } // namespace c10
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|