CallOnce.h 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/C++17.h>
  5. #include <atomic>
  6. #include <functional>
  7. #include <mutex>
  8. #include <utility>
  9. namespace c10 {
  10. // custom c10 call_once implementation to avoid the deadlock in std::call_once.
  11. // The implementation here is a simplified version from folly and likely much
  12. // much higher memory footprint.
  13. template <typename Flag, typename F, typename... Args>
  14. inline void call_once(Flag& flag, F&& f, Args&&... args) {
  15. if (C10_LIKELY(flag.test_once())) {
  16. return;
  17. }
  18. flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
  19. }
  20. class once_flag {
  21. public:
  22. #ifndef _WIN32
  23. // running into build error on MSVC. Can't seem to get a repro locally so I'm
  24. // just avoiding constexpr
  25. //
  26. // C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
  27. // defaulted default constructor cannot be constexpr because the
  28. // corresponding implicitly declared default constructor would not be
  29. // constexpr 1 error detected in the compilation of
  30. // "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
  31. constexpr
  32. #endif
  33. once_flag() noexcept = default;
  34. once_flag(const once_flag&) = delete;
  35. once_flag& operator=(const once_flag&) = delete;
  36. once_flag(once_flag&&) = delete;
  37. once_flag& operator=(once_flag&&) = delete;
  38. ~once_flag() = default;
  39. bool test_once() {
  40. return init_.load(std::memory_order_acquire);
  41. }
  42. private:
  43. template <typename Flag, typename F, typename... Args>
  44. friend void call_once(Flag& flag, F&& f, Args&&... args);
  45. template <typename F, typename... Args>
  46. void call_once_slow(F&& f, Args&&... args) {
  47. std::lock_guard<std::mutex> guard(mutex_);
  48. if (init_.load(std::memory_order_relaxed)) {
  49. return;
  50. }
  51. std::invoke(std::forward<F>(f), std::forward<Args>(args)...);
  52. init_.store(true, std::memory_order_release);
  53. }
  54. void reset_once() {
  55. init_.store(false, std::memory_order_release);
  56. }
  57. private:
  58. std::mutex mutex_;
  59. std::atomic<bool> init_{false};
  60. };
  61. } // namespace c10
  62. #else
  63. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  64. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)