LeftRight.h 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/Synchronized.h>
  5. #include <array>
  6. #include <atomic>
  7. #include <mutex>
  8. #include <thread>
  9. namespace c10 {
  namespace detail {
  // Scope guard used by LeftRight::read(): adds one to *counter when the guard
  // is created and takes that one back when the guard is destroyed.
  // The pointed-to counter must outlive the guard.
  struct IncrementRAII final {
   public:
    explicit IncrementRAII(std::atomic<int32_t>* counter) : _counter{counter} {
      _counter->fetch_add(1, std::memory_order_seq_cst);
    }

    ~IncrementRAII() {
      _counter->fetch_sub(1, std::memory_order_seq_cst);
    }

    // The guard is tied to the scope that created it: it can be neither
    // copied nor moved.
    IncrementRAII(const IncrementRAII&) = delete;
    IncrementRAII& operator=(const IncrementRAII&) = delete;
    IncrementRAII(IncrementRAII&&) = delete;
    IncrementRAII& operator=(IncrementRAII&&) = delete;

   private:
    std::atomic<int32_t>* _counter;
  };
  } // namespace detail
  26. // LeftRight wait-free readers synchronization primitive
  27. // https://hal.archives-ouvertes.fr/hal-01207881/document
  28. //
  29. // LeftRight is quite easy to use (it can make an arbitrary
  30. // data structure permit wait-free reads), but it has some
  31. // particular performance characteristics you should be aware
  32. // of if you're deciding to use it:
  33. //
  34. // - Reads still incur an atomic write (this is how LeftRight
  35. // keeps track of how long it needs to keep around the old
  36. // data structure)
  37. //
  38. // - Writes get executed twice, to keep both the left and right
  39. // versions up to date. So if your write is expensive or
  40. // nondeterministic, this is also an inappropriate structure
  41. //
  42. // LeftRight is used fairly rarely in PyTorch's codebase. If you
  43. // are still not sure if you need it or not, consult your local
  44. // C++ expert.
  45. //
  46. template <class T>
  47. class LeftRight final {
  48. public:
  49. template <class... Args>
  50. explicit LeftRight(const Args&... args)
  51. : _counters{{{0}, {0}}},
  52. _foregroundCounterIndex(0),
  53. _foregroundDataIndex(0),
  54. _data{{T{args...}, T{args...}}} {}
  55. // Copying and moving would not be threadsafe.
  56. // Needs more thought and careful design to make that work.
  57. LeftRight(const LeftRight&) = delete;
  58. LeftRight(LeftRight&&) noexcept = delete;
  59. LeftRight& operator=(const LeftRight&) = delete;
  60. LeftRight& operator=(LeftRight&&) noexcept = delete;
  61. ~LeftRight() {
  62. // wait until any potentially running writers are finished
  63. {
  64. std::unique_lock<std::mutex> lock(_writeMutex);
  65. }
  66. // wait until any potentially running readers are finished
  67. while (_counters[0].load() != 0 || _counters[1].load() != 0) {
  68. std::this_thread::yield();
  69. }
  70. }
  71. template <typename F>
  72. auto read(F&& readFunc) const {
  73. detail::IncrementRAII _increment_counter(
  74. &_counters[_foregroundCounterIndex.load()]);
  75. return std::forward<F>(readFunc)(_data[_foregroundDataIndex.load()]);
  76. }
  77. // Throwing an exception in writeFunc is ok but causes the state to be either
  78. // the old or the new state, depending on if the first or the second call to
  79. // writeFunc threw.
  80. template <typename F>
  81. auto write(F&& writeFunc) {
  82. std::unique_lock<std::mutex> lock(_writeMutex);
  83. return _write(std::forward<F>(writeFunc));
  84. }
  85. private:
  86. template <class F>
  87. auto _write(const F& writeFunc) {
  88. /*
  89. * Assume, A is in background and B in foreground. In simplified terms, we
  90. * want to do the following:
  91. * 1. Write to A (old background)
  92. * 2. Switch A/B
  93. * 3. Write to B (new background)
  94. *
  95. * More detailed algorithm (explanations on why this is important are below
  96. * in code):
  97. * 1. Write to A
  98. * 2. Switch A/B data pointers
  99. * 3. Wait until A counter is zero
  100. * 4. Switch A/B counters
  101. * 5. Wait until B counter is zero
  102. * 6. Write to B
  103. */
  104. auto localDataIndex = _foregroundDataIndex.load();
  105. // 1. Write to A
  106. _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  107. // 2. Switch A/B data pointers
  108. localDataIndex = localDataIndex ^ 1;
  109. _foregroundDataIndex = localDataIndex;
  110. /*
  111. * 3. Wait until A counter is zero
  112. *
  113. * In the previous write run, A was foreground and B was background.
  114. * There was a time after switching _foregroundDataIndex (B to foreground)
  115. * and before switching _foregroundCounterIndex, in which new readers could
  116. * have read B but incremented A's counter.
  117. *
  118. * In this current run, we just switched _foregroundDataIndex (A back to
  119. * foreground), but before writing to the new background B, we have to make
  120. * sure A's counter was zero briefly, so all these old readers are gone.
  121. */
  122. auto localCounterIndex = _foregroundCounterIndex.load();
  123. _waitForBackgroundCounterToBeZero(localCounterIndex);
  124. /*
  125. * 4. Switch A/B counters
  126. *
  127. * Now that we know all readers on B are really gone, we can switch the
  128. * counters and have new readers increment A's counter again, which is the
  129. * correct counter since they're reading A.
  130. */
  131. localCounterIndex = localCounterIndex ^ 1;
  132. _foregroundCounterIndex = localCounterIndex;
  133. /*
  134. * 5. Wait until B counter is zero
  135. *
  136. * This waits for all the readers on B that came in while both data and
  137. * counter for B was in foreground, i.e. normal readers that happened
  138. * outside of that brief gap between switching data and counter.
  139. */
  140. _waitForBackgroundCounterToBeZero(localCounterIndex);
  141. // 6. Write to B
  142. return _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
  143. }
  144. template <class F>
  145. auto _callWriteFuncOnBackgroundInstance(
  146. const F& writeFunc,
  147. uint8_t localDataIndex) {
  148. try {
  149. return writeFunc(_data[localDataIndex ^ 1]);
  150. } catch (...) {
  151. // recover invariant by copying from the foreground instance
  152. _data[localDataIndex ^ 1] = _data[localDataIndex];
  153. // rethrow
  154. throw;
  155. }
  156. }
  157. void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
  158. while (_counters[counterIndex ^ 1].load() != 0) {
  159. std::this_thread::yield();
  160. }
  161. }
  162. mutable std::array<std::atomic<int32_t>, 2> _counters;
  163. std::atomic<uint8_t> _foregroundCounterIndex;
  164. std::atomic<uint8_t> _foregroundDataIndex;
  165. std::array<T, 2> _data;
  166. std::mutex _writeMutex;
  167. };
  168. // RWSafeLeftRightWrapper is API compatible with LeftRight and uses a
  169. // read-write lock to protect T (data).
  170. template <class T>
  171. class RWSafeLeftRightWrapper final {
  172. public:
  173. template <class... Args>
  174. explicit RWSafeLeftRightWrapper(const Args&... args) : data_{args...} {}
  175. // RWSafeLeftRightWrapper is not copyable or moveable since LeftRight
  176. // is not copyable or moveable.
  177. RWSafeLeftRightWrapper(const RWSafeLeftRightWrapper&) = delete;
  178. RWSafeLeftRightWrapper(RWSafeLeftRightWrapper&&) noexcept = delete;
  179. RWSafeLeftRightWrapper& operator=(const RWSafeLeftRightWrapper&) = delete;
  180. RWSafeLeftRightWrapper& operator=(RWSafeLeftRightWrapper&&) noexcept = delete;
  181. ~RWSafeLeftRightWrapper() = default;
  182. template <typename F>
  183. // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  184. auto read(F&& readFunc) const {
  185. return data_.withLock(
  186. [&readFunc](T const& data) { return std::forward<F>(readFunc)(data); });
  187. }
  188. template <typename F>
  189. // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  190. auto write(F&& writeFunc) {
  191. return data_.withLock(
  192. [&writeFunc](T& data) { return std::forward<F>(writeFunc)(data); });
  193. }
  194. private:
  195. c10::Synchronized<T> data_;
  196. };
  197. } // namespace c10
  198. #else
  199. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  200. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)