WaitCounter.h 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  #include <chrono>
  #include <cstdint>
  #include <memory>
  #include <string_view>
  #include <utility>
  #include <vector>

  #include <c10/macros/Macros.h>
  #include <c10/util/ScopeExit.h>
  #include <c10/util/SmallVector.h>
  10. namespace c10::monitor {
  11. namespace detail {
  12. class WaitCounterImpl;
  /// Interface implemented by a single wait counter backend.
  ///
  /// Both hooks are declared `noexcept`, so implementations must not let
  /// exceptions escape.
  class WaitCounterBackendIf {
   public:
    virtual ~WaitCounterBackendIf() = default;

    /// Invoked with the time point at which a wait begins.
    ///
    /// @param now  Steady-clock timestamp supplied by the caller.
    /// @return     An opaque, backend-defined context value.
    virtual intptr_t start(std::chrono::steady_clock::time_point now) noexcept = 0;

    /// Invoked with the time point at which a wait ends.
    ///
    /// @param now  Steady-clock timestamp supplied by the caller.
    /// @param ctx  Backend-defined context value; presumably the one returned
    ///             by the matching start() call (TODO: confirm against the
    ///             WaitCounterImpl implementation).
    virtual void stop(std::chrono::steady_clock::time_point now, intptr_t ctx) noexcept = 0;
  };
  22. class WaitCounterBackendFactoryIf {
  23. public:
  24. virtual ~WaitCounterBackendFactoryIf() = default;
  25. // May return nullptr.
  26. // In this case the counter will be ignored by the given backend.
  27. virtual std::unique_ptr<WaitCounterBackendIf> create(
  28. std::string_view key) noexcept = 0;
  29. };
  30. C10_API void registerWaitCounterBackend(
  31. std::unique_ptr<WaitCounterBackendFactoryIf> /*factory*/);
  32. C10_API std::vector<std::shared_ptr<WaitCounterBackendFactoryIf>>
  33. getRegisteredWaitCounterBackends();
  34. } // namespace detail
  35. // A handle to a wait counter.
  36. class C10_API WaitCounterHandle {
  37. public:
  38. explicit WaitCounterHandle(std::string_view key);
  39. class WaitGuard {
  40. public:
  41. WaitGuard(WaitGuard&& other) noexcept
  42. : handle_{std::exchange(other.handle_, {})},
  43. ctxs_{std::move(other.ctxs_)} {}
  44. WaitGuard(const WaitGuard&) = delete;
  45. WaitGuard& operator=(const WaitGuard&) = delete;
  46. WaitGuard& operator=(WaitGuard&&) = delete;
  47. ~WaitGuard() {
  48. stop();
  49. }
  50. void stop() {
  51. if (auto handle = std::exchange(handle_, nullptr)) {
  52. handle->stop(ctxs_);
  53. }
  54. }
  55. private:
  56. WaitGuard(WaitCounterHandle& handle, SmallVector<intptr_t>&& ctxs)
  57. : handle_{&handle}, ctxs_{std::move(ctxs)} {}
  58. friend class WaitCounterHandle;
  59. WaitCounterHandle* handle_;
  60. SmallVector<intptr_t> ctxs_;
  61. };
  62. // Starts a waiter
  63. WaitGuard start();
  64. private:
  65. // Stops the waiter. Each start() call should be matched by exactly one stop()
  66. // call.
  67. void stop(const SmallVector<intptr_t>& ctxs);
  68. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  69. detail::WaitCounterImpl& impl_;
  70. };
  71. } // namespace c10::monitor
  // Expands to an expression yielding a reference to a function-local static
  // ::c10::monitor::WaitCounterHandle. The handle's key is the stringified
  // macro argument. Every expansion introduces its own lambda and therefore
  // its own static handle object.
  #define STATIC_WAIT_COUNTER(counter_key) \
    []() -> ::c10::monitor::WaitCounterHandle& { \
      static ::c10::monitor::WaitCounterHandle counter_handle(#counter_key); \
      return counter_handle; \
    }()
  // Starts the static wait counter named by the argument and binds the
  // resulting guard to an anonymous local variable, so the wait ends when the
  // enclosing scope exits. The expansion already ends with a semicolon.
  #define STATIC_SCOPED_WAIT_COUNTER(counter_name) \
    auto C10_ANONYMOUS_VARIABLE(SCOPE_GUARD) = \
        STATIC_WAIT_COUNTER(counter_name).start();
  // Evaluates `wrapped_expr` inside an immediately-invoked by-reference lambda
  // while a scoped guard for the named static wait counter is alive, and
  // returns the expression's value. The expansion already ends with a
  // semicolon, so it is usable as (the tail of) a statement only.
  #define WITH_WAIT_COUNTER(counter_name, wrapped_expr) \
    [&] { \
      STATIC_SCOPED_WAIT_COUNTER(counter_name); \
      return wrapped_expr; \
    }();
  84. #else
  85. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  86. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)