atomic.h 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <metal_atomic>
  4. namespace c10 {
  5. namespace metal {
  // Atomic operations helper.
  // AtomicType<T> names the atomic storage type used to hold elements of type T
  // (`type`) and provides a static atomic_add(data, offset, value) routine.
  // The primary template is deliberately empty, so only the explicit
  // specializations provide `type`/atomic_add; AtomicType_t<T> is ill-formed
  // for any other T.
  template <class T>
  struct AtomicType {};

  // Shorthand for the atomic storage type chosen by AtomicType<T>.
  template <class T>
  using AtomicType_t = typename AtomicType<T>::type;
  // float: backed by Metal's native atomic<float>, so the add is a single
  // relaxed atomic_fetch_add_explicit on the addressed element.
  template <>
  struct AtomicType<float> {
    using type = ::metal::atomic<float>;
    static inline void atomic_add(device type* data, long offset, float value) {
      const auto slot = data + offset;
      ::metal::atomic_fetch_add_explicit(
          slot, value, ::metal::memory_order_relaxed);
    }
  };
  // int: backed by Metal's native atomic<int>, so the add is a single relaxed
  // atomic_fetch_add_explicit on the addressed element.
  template <>
  struct AtomicType<int> {
    using type = ::metal::atomic<int>;
    static inline void atomic_add(device type* data, long offset, int value) {
      const auto slot = data + offset;
      ::metal::atomic_fetch_add_explicit(
          slot, value, ::metal::memory_order_relaxed);
    }
  };
  // As of Metal 3.2, atomic operations are not supported on half-precision
  // floats (or other sub-32-bit types), so they are simulated with an atomic
  // compare-and-exchange loop over the 32-bit word that contains the element.
  //
  // data   - buffer viewed as an array of 32-bit atomic words
  // offset - index of the target element, counted in units of T
  // value  - amount added to that element
  template <typename T>
  static inline void atomic_add_helper(
      device ::metal::atomic<uint>* data,
      long offset,
      T value) {
    // The word index / lane mask arithmetic below only makes sense when T packs
    // evenly into a 32-bit word (1, 2 or 4 bytes, all powers of two). A wider T
    // would otherwise yield elems_per_word == 0: a zero-length union array and
    // a division by zero.
    static_assert(
        sizeof(T) <= sizeof(uint) && (sizeof(uint) % sizeof(T)) == 0,
        "atomic_add_helper requires a type that packs evenly into 32 bits");
    constexpr auto elems_per_word = sizeof(uint) / sizeof(T);
    auto ptr = data + (offset / elems_per_word);
    auto old = ::metal::atomic_load_explicit(ptr, ::metal::memory_order_relaxed);
    union {
      uint i;
      T t[elems_per_word];
    } val;
    do {
      // Rebuild the candidate word from the most recently observed contents;
      // a failed exchange writes the current contents of *ptr back into `old`.
      val.i = old;
      val.t[offset & (elems_per_word - 1)] += value;
    } while (!::metal::atomic_compare_exchange_weak_explicit(
        ptr,
        &old,
        val.i,
        ::metal::memory_order_relaxed,
        ::metal::memory_order_relaxed));
  }
  // half: Metal has no native atomic add for 16-bit floats, so half values are
  // stored inside 32-bit atomic words and updated by atomic_add_helper's CAS loop.
  template <>
  struct AtomicType<half> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(
        device ::metal::atomic<uint>* data,
        long offset,
        half value) {
      atomic_add_helper<half>(data, offset, value);
    }
  };
  // short: stored inside 32-bit atomic words and updated through the
  // CAS-based atomic_add_helper.
  template <>
  struct AtomicType<short> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(
        device ::metal::atomic<uint>* data,
        long offset,
        short value) {
      atomic_add_helper<short>(data, offset, value);
    }
  };
  // char: stored inside 32-bit atomic words and updated through the
  // CAS-based atomic_add_helper.
  template <>
  struct AtomicType<char> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(
        device ::metal::atomic<uint>* data,
        long offset,
        char value) {
      atomic_add_helper<char>(data, offset, value);
    }
  };
  // uchar: stored inside 32-bit atomic words and updated through the
  // CAS-based atomic_add_helper. The operand is declared uchar to match the
  // specialization, so the helper is instantiated for uchar. Byte arithmetic is
  // then well-defined modulo 256 instead of going through implementation-defined
  // narrowing to (signed) char for values above 127.
  template <>
  struct AtomicType<uchar> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(device type* data, long offset, uchar value) {
      atomic_add_helper(data, offset, value);
    }
  };
  // bfloat: stored inside 32-bit atomic words and updated through the
  // CAS-based atomic_add_helper (T is deduced as bfloat from the operand).
  template <>
  struct AtomicType<bfloat> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(
        device ::metal::atomic<uint>* data,
        long offset,
        bfloat value) {
      atomic_add_helper(data, offset, value);
    }
  };
  // Metal supports atomic_store_explicit for bools, but
  // sizeof(::metal::atomic_bool) is 4, so it cannot be used to atomically modify
  // individual bools packed in (possibly unaligned) memory. Fall back to the
  // compare-and-exchange trick on the enclosing 32-bit word instead.
  // Accumulating booleans is just a logical OR, so adding `false` does nothing.
  template <>
  struct AtomicType<bool> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(device type* data, long offset, bool value) {
      if (value) {
        // The code treats four bools as sharing each 32-bit word:
        // offset >> 2 picks the word and offset & 3 picks the byte lane in it.
        auto word = data + (offset >> 2);
        const auto lane = offset & 3;
        auto expected =
            ::metal::atomic_load_explicit(word, ::metal::memory_order_relaxed);
        union {
          uint bits;
          bool flags[4];
        } desired;
        for (;;) {
          desired.bits = expected;
          desired.flags[lane] = true;
          if (::metal::atomic_compare_exchange_weak_explicit(
                  word,
                  &expected,
                  desired.bits,
                  ::metal::memory_order_relaxed,
                  ::metal::memory_order_relaxed)) {
            break;
          }
        }
      }
    }
  };
  // ComplexHalf atomic op. A half2 is exactly as wide as a uint, so every
  // element occupies one 32-bit atomic word. The word is reinterpreted as half2,
  // the operand is added, and the result is published with a CAS loop.
  template <>
  struct AtomicType<half2> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(device type* data, long offset, half2 value) {
      auto word = data + offset;
      uint current =
          ::metal::atomic_load_explicit(word, ::metal::memory_order_relaxed);
      bool stored = false;
      do {
        const uint updated = as_type<uint>(as_type<half2>(current) + value);
        // On failure `current` is refreshed, so the next sum uses fresh data.
        stored = ::metal::atomic_compare_exchange_weak_explicit(
            word,
            &current,
            updated,
            ::metal::memory_order_relaxed,
            ::metal::memory_order_relaxed);
      } while (!stored);
    }
  };
  // There is no 64-bit atomic add in Metal yet. The specialization below
  // implements an eventually consistent add instead: each 64-bit value is split
  // into two 32-bit words that are updated by separate atomic operations.
  // Concurrent readers may observe a torn value. Once every thread modifying
  // the same 64-bit element has finished, the stored result equals its original
  // value plus the sum of all operands.
  template <>
  struct AtomicType<long> {
    using type = ::metal::atomic<uint>;
    static inline void atomic_add(device type* data, long offset, long value) {
      const ulong bits = as_type<ulong>(value);
      const uint lo = static_cast<uint>(bits);
      const uint hi = static_cast<uint>(bits >> 32);
      // Element `offset` spans words 2*offset (low half) and 2*offset+1 (high).
      auto lo_word = data + (offset << 1);
      const uint prev_lo = ::metal::atomic_fetch_add_explicit(
          lo_word, lo, ::metal::memory_order_relaxed);
      // If adding `lo` wrapped the low word, carry one into the high word.
      const uint carry = (prev_lo + lo < prev_lo) ? 1 : 0;
      ::metal::atomic_fetch_add_explicit(
          lo_word + 1, hi + carry, ::metal::memory_order_relaxed);
    }
  };
  // ComplexFloat atomic op. This is not a single atomic operation: element
  // `offset` occupies two consecutive atomic<float> words (x, then y) that are
  // updated independently, so the result is only eventually consistent.
  template <>
  struct AtomicType<float2> {
    using type = ::metal::atomic<float>;
    static inline void atomic_add(device type* data, long offset, float2 value) {
      auto parts = data + (offset << 1);
      ::metal::atomic_fetch_add_explicit(
          parts, value.x, ::metal::memory_order_relaxed);
      ::metal::atomic_fetch_add_explicit(
          parts + 1, value.y, ::metal::memory_order_relaxed);
    }
  };
  163. } // namespace metal
  164. } // namespace c10
  165. #else
  166. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  167. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)