Atomic.cuh 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <cuda.h>
  4. #include <c10/util/Half.h>
  5. #include <c10/util/BFloat16.h>
  6. #include <ATen/NumericUtils.h>
  7. #if !(defined(USE_ROCM) || ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
  8. #include <cuda_bf16.h>
  9. #endif
  // Primary template: AtomicFPOp<T> is only defined through the explicit
  // specializations that follow.
  template <typename T>
  struct AtomicFPOp;

  // at::Half specialization. The hardware CAS works on 32-bit words, so a 16-bit
  // read-modify-write is emulated by CAS-ing the aligned word that contains
  // *address and splicing the updated 16 bits back in. `func(current, val)`
  // returns the new value. The function returns the half that was stored before
  // the successful CAS.
  template <>
  struct AtomicFPOp<at::Half> {
    template <typename func_t>
    inline __device__ at::Half operator() (at::Half *address, at::Half val, const func_t& func) {
      // 0 -> value in the low 16 bits of its word, 2 -> value in the high 16 bits.
      const size_t byte_offset = (size_t)address & 2;
      const bool in_high_half = byte_offset != 0;
      unsigned int *word = (unsigned int *)((char *)address - byte_offset);

      unsigned int observed = *word;
      unsigned int expected;
      at::Half result;
      do {
        expected = observed;
        result.x = in_high_half ? (observed >> 16) : (observed & 0xffff);
        result = func(result, val);
        unsigned int desired = in_high_half
            ? ((observed & 0xffff) | (result.x << 16))
            : ((observed & 0xffff0000) | result.x);
        observed = atomicCAS(word, expected, desired);
        // Raw integer comparison, so a NaN payload cannot make the loop spin forever.
      } while (expected != observed);
      result.x = in_high_half ? (observed >> 16) : (observed & 0xffff);
      return result;
    }
  };
  // at::BFloat16 specialization. The hardware CAS works on 32-bit words, so a
  // 16-bit read-modify-write is emulated by CAS-ing the aligned word that
  // contains *address and splicing the updated 16 bits back in.
  //   address: element to update (2-byte aligned).
  //   val:     operand passed to func.
  //   func:    func(current, val) -> new value, both at::BFloat16.
  // Returns the at::BFloat16 that was stored at *address before the successful CAS.
  template <>
  struct AtomicFPOp<at::BFloat16> {
    template <typename func_t>
    inline __device__ at::BFloat16 operator() (at::BFloat16 *address, at::BFloat16 val, const func_t& func) {
      // The value occupies either the low (offset 0) or high (offset 2) half of
      // its aligned 32-bit word; atomicCAS operates on that whole word.
      unsigned int * address_as_ui =
          (unsigned int *) ((char *)address - ((size_t)address & 2));
      const bool is_high_half = ((size_t)address & 2) != 0;
      unsigned int old = *address_as_ui;
      unsigned int assumed;
      at::BFloat16 bsum;
      do {
        assumed = old;
        bsum.x = is_high_half ? (old >> 16) : (old & 0xffff);
        bsum = func(bsum, val);
        old = is_high_half ? (old & 0xffff) | (bsum.x << 16) : (old & 0xffff0000) | bsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
        // Integer comparison: terminates even when the stored value is NaN.
      } while (assumed != old);
      bsum.x = is_high_half ? (old >> 16) : (old & 0xffff);
      // Return the object itself, not its bits. Returning `bsum.x` would push the
      // raw uint16_t pattern through BFloat16's implicit numeric (float)
      // constructor, e.g. turning the bits of 1.0 (0x3F80) into the value 16256.
      return bsum;
    }
  };
  // double specialization. Unlike the 16-bit specializations, `func` works on raw
  // bit patterns: it is called as func(val, current_bits) and must return the
  // 64-bit pattern to store. Returns the double stored before the successful CAS.
  template <>
  struct AtomicFPOp<double> {
    template <typename func_t>
    inline __device__ double operator() (double * address, double val, const func_t& func) {
      unsigned long long int *bits_ptr = reinterpret_cast<unsigned long long int *>(address);
      unsigned long long int seen = *bits_ptr;
      for (;;) {
        const unsigned long long int expected = seen;
        seen = atomicCAS(bits_ptr, expected, func(val, expected));
        // Compare bit patterns rather than doubles: NaN != NaN would never exit.
        if (seen == expected) {
          break;
        }
      }
      return __longlong_as_double(seen);
    }
  };
  // ATOMIC_INTEGER_IMPL(NAME) declares the functor template
  // Atomic<NAME>IntegerImpl<T, n>. Its operator()(address, val, func) atomically
  // replaces *address with func(val, *address) using an atomicCAS retry loop, and
  // it returns nothing. n is sizeof(T):
  //   - 1- and 2-byte values are updated by CAS-ing the aligned 32-bit word that
  //     contains them and splicing the new bits in.
  //   - 4- and 8-byte values are CAS-ed directly.
  // Only block comments may appear inside the macro body. A line comment there
  // would comment out every following line spliced onto it by the backslashes.
  #define ATOMIC_INTEGER_IMPL(NAME) \
  template <typename T, size_t n> \
  struct Atomic##NAME##IntegerImpl; \
  \
  template<typename T> \
  struct Atomic##NAME##IntegerImpl<T, 1> { \
    template <typename func_t> \
    inline __device__ void operator()(T *address, T val, const func_t& func) { \
      /* Byte index inside the aligned 32-bit word and the matching bit mask. */ \
      size_t byte_idx = (size_t)address & 3; \
      uint32_t * word = (uint32_t *)((char *)address - byte_idx); \
      uint32_t bit_shift = byte_idx * 8; \
      uint32_t byte_mask = 0x000000ff << bit_shift; \
      uint32_t seen = *word; \
      uint32_t expected; \
      do { \
        expected = seen; \
        uint32_t current = (seen >> bit_shift) & 0xff; \
        uint32_t updated = static_cast<uint8_t>(func(val, static_cast<T>(current))); \
        seen = atomicCAS(word, expected, (seen & ~byte_mask) | (updated << bit_shift)); \
      } while (expected != seen); \
    } \
  }; \
  \
  template<typename T> \
  struct Atomic##NAME##IntegerImpl<T, 2> { \
    template <typename func_t> \
    inline __device__ void operator()(T *address, T val, const func_t& func) { \
      /* offset is 2 when the value lives in the upper 16 bits of its word. */ \
      size_t offset = (size_t)address & 2; \
      uint32_t * word = (uint32_t *)((char *)address - offset); \
      bool in_upper_half = offset; \
      uint32_t seen = *word; \
      uint32_t expected; \
      do { \
        expected = seen; \
        uint32_t current = in_upper_half ? seen >> 16 : seen & 0xffff; \
        uint32_t updated = static_cast<uint16_t>(func(val, static_cast<T>(current))); \
        uint32_t desired = in_upper_half ? (seen & 0xffff) | (updated << 16) \
                                         : (seen & 0xffff0000) | updated; \
        seen = atomicCAS(word, expected, desired); \
      } while (expected != seen); \
    } \
  }; \
  \
  template<typename T> \
  struct Atomic##NAME##IntegerImpl<T, 4> { \
    template <typename func_t> \
    inline __device__ void operator()(T *address, T val, const func_t& func) { \
      uint32_t * word = (uint32_t *) (address); \
      uint32_t seen = *word; \
      uint32_t expected; \
      do { \
        expected = seen; \
        uint32_t desired = static_cast<uint32_t>(func(val, static_cast<T>(seen))); \
        seen = atomicCAS(word, expected, desired); \
      } while (expected != seen); \
    } \
  }; \
  \
  template<typename T> \
  struct Atomic##NAME##IntegerImpl<T, 8> { \
    template <typename func_t> \
    inline __device__ void operator()(T *address, T val, const func_t& func) { \
      unsigned long long * word = (unsigned long long *) (address); \
      unsigned long long seen = *word; \
      unsigned long long expected; \
      do { \
        expected = seen; \
        unsigned long long desired = static_cast<uint64_t>(func(val, static_cast<T>(seen))); \
        seen = atomicCAS(word, expected, desired); \
      } while (expected != seen); \
    } \
  };
  // GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) defines
  //   inline __device__ void gpuAtomic<NAME>(DTYPE *address, DTYPE val)
  // which atomically applies the expression OP through Atomic<NAME>IntegerImpl.
  // Inside OP, `a` is `val` and `b` is the current value at *address.
  // ATOMIC_INTEGER_IMPL(NAME) must already have been expanded at the point of use.
  //
  // The final line deliberately has no trailing backslash. A continuation there
  // would splice the next source line (e.g. ATOMIC_INTEGER_IMPL(Add)) into this
  // macro's body.
  #define GPU_ATOMIC_INTEGER(NAME, OP, DTYPE) \
  inline __device__ void gpuAtomic##NAME(DTYPE *address, DTYPE val) { \
    Atomic##NAME##IntegerImpl<DTYPE, sizeof(DTYPE)>()(address, \
                                                      val, \
                                                      [](DTYPE a, DTYPE b) { \
                                                        return OP; \
                                                      }); \
  }
  // Integer atomic add. ATOMIC_INTEGER_IMPL(Add) defines AtomicAddIntegerImpl.
  // The bool overload (which computes a logical OR) is generated with GPU_ATOMIC_INTEGER.
  ATOMIC_INTEGER_IMPL(Add)
  GPU_ATOMIC_INTEGER(Add, a || b, bool)

  // uint8_t, int8_t and int16_t are written out by hand rather than generated with
  // GPU_ATOMIC_INTEGER. The original note says macro instantiation of gpuAtomicAdd
  // "seems non-standard" next to the hand-written int32_t/int64_t overloads.
  // Each overload CAS-emulates the add on the enclosing 32-bit word. The sum is
  // narrowed back to the element width, so it wraps, and nothing is returned.
  inline __device__ void gpuAtomicAdd(uint8_t *address, uint8_t val) {
    const auto plus = [](uint8_t lhs, uint8_t rhs) { return lhs + rhs; };
    AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val, plus);
  }

  inline __device__ void gpuAtomicAdd(int8_t *address, int8_t val) {
    const auto plus = [](int8_t lhs, int8_t rhs) { return lhs + rhs; };
    AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val, plus);
  }

  inline __device__ void gpuAtomicAdd(int16_t *address, int16_t val) {
    const auto plus = [](int16_t lhs, int16_t rhs) { return lhs + rhs; };
    AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val, plus);
  }

  // int32_t maps onto the native atomicAdd and, unlike the narrower overloads,
  // returns atomicAdd's result.
  inline __device__ int32_t gpuAtomicAdd(int32_t *address, int32_t val) {
    const int32_t result = atomicAdd(address, val);
    return result;
  }
  // 64-bit signed atomic add; returns nothing.
  //  - CUDA: reuses the unsigned long long overload of atomicAdd. With two's
  //    complement, unsigned addition yields the same bit pattern as signed addition.
  //  - ROCm: uses the compiler builtin with relaxed ordering.
  inline __device__ void gpuAtomicAdd(int64_t *address, int64_t val) {
  #if !defined(USE_ROCM)
    static_assert(sizeof(unsigned long long int) == sizeof(int64_t), "bitwidth change is not allowed");
    using ull_t = unsigned long long int;
    atomicAdd(reinterpret_cast<ull_t *>(address), static_cast<ull_t>(val));
  #else
    __atomic_fetch_add(address, val, __ATOMIC_RELAXED);
  #endif
  }
  // Atomic add for at::Half; returns the result of the underlying atomic.
  //  - CUDA on sm_70+ (and the host pass): native __half atomicAdd.
  //  - ROCm and CUDA below sm_70: the CAS loop in AtomicFPOp<at::Half>.
  inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) {
  #if !defined(USE_ROCM) && !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))
    return atomicAdd(reinterpret_cast<__half*>(address), val);
  #else
    const auto add = [](at::Half current, at::Half delta) { return current + delta; };
    return AtomicFPOp<at::Half>()(address, val, add);
  #endif
  }
  // Atomic add for at::BFloat16; returns the result of the underlying atomic.
  //  - CUDA on sm_80+ (and the host pass): native __nv_bfloat16 atomicAdd.
  //    Values are bit-cast between c10::BFloat16 and __nv_bfloat16.
  //  - ROCm and CUDA below sm_80: the CAS loop in AtomicFPOp<at::BFloat16>.
  inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) {
  #if !defined(USE_ROCM) && !(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
    __nv_bfloat16 *native_address = reinterpret_cast<__nv_bfloat16*>(address);
    const __nv_bfloat16 native_val = *reinterpret_cast<__nv_bfloat16*>(&val);
    __nv_bfloat16 r = atomicAdd(native_address, native_val);
    return *reinterpret_cast<c10::BFloat16*>(&r);
  #else
    const auto add = [](at::BFloat16 current, at::BFloat16 delta) { return current + delta; };
    return AtomicFPOp<at::BFloat16>()(address, val, add);
  #endif
  }
  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600)
  // Software atomicAdd for double on architectures below sm_60, which have no
  // native double-precision atomicAdd. Adapted from the CUDA C Programming Guide:
  // a 64-bit atomicCAS loop (AtomicFPOp<double>) whose result is returned.
  inline __device__ double atomicAdd(double* address, double val)
  #if defined(__clang__) && defined(__CUDA__)
  // NOTE(review): under clang CUDA, enable_if presumably makes this overload
  // preferred over the toolchain's own declaration. -Wgcc-compat is silenced
  // because enable_if is a clang extension. Confirm before changing.
  #pragma GCC diagnostic push
  #pragma GCC diagnostic ignored "-Wgcc-compat"
  __attribute__((enable_if(true, "")))
  #pragma GCC diagnostic pop
  #endif
  {
    // AtomicFPOp<double> hands the lambda the current value as raw bits and
    // stores the raw bits it returns.
    return AtomicFPOp<double>()(
        address, val, [](double addend, unsigned long long int current_bits) {
          return __double_as_longlong(addend + __longlong_as_double(current_bits));
        });
  }
  #elif defined(USE_ROCM) || !(defined(__CUDA_ARCH__))
  /* Note [hip-clang differences to hcc]
   * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   * The hip-clang compiler for ROCm differs from hcc in a few details. It
   * exports the __HIP__ macro, so hcc and hip-clang can be told apart. Below,
   * hcc only gained support for atomicAdd on double after work week 18312,
   * whereas hip-clang supported it from its first version. In general, the
   * code-visible differences between hip-clang and hcc are minimal.
   */
  #if defined(USE_ROCM) && __hcc_workweek__ < 18312 && !__HIP__
  // Declared so that gpuAtomicAdd(double) still resolves during the host-side
  // pass on old hcc.
  // NOTE(review): the body is empty, so this non-void function has no return
  // statement. It is presumably never executed; confirm.
  inline __device__ double atomicAdd(double *address, double val) { }
  #endif
  #endif
  // double and float forward to the atomicAdd overload for the type and return its result.
  inline __device__ double gpuAtomicAdd(double *address, double val) {
    const double result = atomicAdd(address, val);
    return result;
  }

  inline __device__ float gpuAtomicAdd(float *address, float val) {
    const float result = atomicAdd(address, val);
    return result;
  }

  // Complex add is two separate atomics: real_ first, then imag_. Each component
  // is updated atomically, but the pair is not updated as one unit.
  template<typename T>
  inline __device__ void gpuAtomicAdd(c10::complex<T> *address, c10::complex<T> val) {
    T *const real_part = &address->real_;
    T *const imag_part = &address->imag_;
    gpuAtomicAdd(real_part, val.real_);
    gpuAtomicAdd(imag_part, val.imag_);
  }
  /* Note [gpuAtomicAdd vs atomicAdd]
   * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   * Some extensions, such as torchvision, call atomicAdd() directly and need
   * overloads for data types the CUDA/HIP libraries do not provide. Only for
   * these types do we continue to provide atomicAdd overloads; each of them
   * simply forwards to gpuAtomicAdd.
   */
  inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
    const at::Half result = gpuAtomicAdd(address, val);
    return result;
  }

  inline __device__ at::BFloat16 atomicAdd(at::BFloat16 *address, at::BFloat16 val) {
    const at::BFloat16 result = gpuAtomicAdd(address, val);
    return result;
  }

  inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
    return gpuAtomicAdd(address, val);
  }

  inline __device__ void atomicAdd(int8_t *address, int8_t val) {
    return gpuAtomicAdd(address, val);
  }

  inline __device__ void atomicAdd(int16_t *address, int16_t val) {
    return gpuAtomicAdd(address, val);
  }

  inline __device__ void atomicAdd(int64_t *address, int64_t val) {
    return gpuAtomicAdd(address, val);
  }

  inline __device__ void atomicAdd(bool *address, bool val) {
    return gpuAtomicAdd(address, val);
  }
  /* Note [explicitly non-returning atomics]
   * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   * AMD's MI100 (gfx908) provides an optimized fp32 atomicAdd, exposed through
   * atomicAddNoRet(). Because of compiler limitations, callers must opt in to get
   * the optimized instruction. A non-returning atomicAddNoRet cannot implement
   * the value-returning atomicAdd, hence the separate 'gpuAtomicAddNoReturn' API.
   *
   * The overloads below simply call gpuAtomicAdd and discard any result.
   */
  template<typename T>
  inline __device__ void gpuAtomicAddNoReturn(c10::complex<T> *address, c10::complex<T> val) {
    gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(uint8_t *address, uint8_t val) {
    gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(int8_t *address, int8_t val) {
    gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(int16_t *address, int16_t val) {
    gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(int32_t *address, int32_t val) {
    (void)gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(int64_t *address, int64_t val) {
    gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(bool *address, bool val) {
    gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(at::Half *address, at::Half val) {
    (void)gpuAtomicAdd(address, val);
  }

  inline __device__ void gpuAtomicAddNoReturn(at::BFloat16 *address, at::BFloat16 val) {
    (void)gpuAtomicAdd(address, val);
  }
  /* Note [HIP unsafeAtomicAdd]
   * ~~~~~~~~~~~~~~~~~~~~~~~~~~
   * On HIP, use unsafeAtomicAdd instead of atomicAdd for fp32 and fp64.
   * atomicAdd is always correct there but is a slow CAS loop. unsafeAtomicAdd
   * uses hardware instructions and is much faster, but the caller must
   * guarantee that the pointer refers to GPU memory. If it points to system
   * memory, the operation silently does nothing. All PyTorch uses of
   * unsafeAtomicAdd uphold this guarantee. The AMD HIP atomic header is named
   * amd_hip_atomic.h and lives under the LLVM compiler directory.
   */
  #if !defined(USE_ROCM)
  // CUDA: plain forwarding to gpuAtomicAdd, discarding its result.
  inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
    gpuAtomicAdd(address, val);
  }
  inline __device__ void gpuAtomicAddNoReturn(double *address, double val) {
    gpuAtomicAdd(address, val);
  }
  #else
  // ROCm: gfx908 uses atomicAddNoRet; other targets use unsafeAtomicAdd.
  inline __device__ void gpuAtomicAddNoReturn(float *address, float val) {
  #if !defined(__gfx908__)
    (void)unsafeAtomicAdd(address, val);
  #else
    atomicAddNoRet(address, val);
  #endif
  }
  inline __device__ void gpuAtomicAddNoReturn(double *address, double val) {
    (void)unsafeAtomicAdd(address, val);
  }
  #endif
  // Atomic multiplication implementation.
  // Integer overloads are generated by the macros and return nothing.
  ATOMIC_INTEGER_IMPL(Mul)
  GPU_ATOMIC_INTEGER(Mul, a * b, uint8_t)
  GPU_ATOMIC_INTEGER(Mul, a * b, int8_t)
  GPU_ATOMIC_INTEGER(Mul, a * b, int16_t)
  GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
  GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

  // Half / BFloat16 / double forward to the matching AtomicFPOp specialization
  // and return its result.
  inline __device__ at::Half gpuAtomicMul(at::Half * address, at::Half val) {
    const auto times = [](at::Half current, at::Half factor) { return current * factor; };
    return AtomicFPOp<at::Half>()(address, val, times);
  }

  inline __device__ at::BFloat16 gpuAtomicMul(at::BFloat16 * address, at::BFloat16 val) {
    const auto times = [](at::BFloat16 current, at::BFloat16 factor) { return current * factor; };
    return AtomicFPOp<at::BFloat16>()(address, val, times);
  }

  inline __device__ double gpuAtomicMul(double * address, double val) {
    // AtomicFPOp<double> passes the current value as raw bits and expects raw bits back.
    const auto times = [](double factor, unsigned long long int current_bits) {
      return __double_as_longlong(factor * __longlong_as_double(current_bits));
    };
    return AtomicFPOp<double>()(address, val, times);
  }

  // There is no AtomicFPOp<float> specialization (float addition goes straight to
  // the built-in atomicAdd), so the CAS loop is written out here. Returns the
  // float stored at *address before the successful CAS.
  inline __device__ float gpuAtomicMul (float * address, float val) {
    unsigned int *bits_ptr = (unsigned int *)address;
    unsigned int seen = *bits_ptr;
    for (;;) {
      const unsigned int expected = seen;
      const float product = val * __int_as_float(expected);
      seen = atomicCAS(bits_ptr, expected, __float_as_int(product));
      // Integer comparison, so a NaN result cannot make the loop spin forever.
      if (seen == expected) {
        break;
      }
    }
    return __int_as_float(seen);
  }
  // Atomic maximum implementation.
  // NaN-propagating max. If `b` is NaN, `b` is returned. Otherwise std::max<T>(a, b)
  // is returned, which yields `a` when `a` is NaN because NaN < b is false.
  template <typename T>
  __host__ __device__ T safe_max(T a, T b) {
  #if defined(__HIPCC__)
    // TODO: remove this special case for HIP when issue is fixed:
    // https://github.com/ROCm/hip/issues/2209
    if (at::_isnan(a)) {
      return a;
    }
  #endif
    if (at::_isnan(b)) {
      return b;
    }
    return std::max<T>(a, b);
  }
  // Atomic max. Integer overloads are generated by the macros and return nothing.
  // All overloads combine the values with the NaN-propagating safe_max.
  ATOMIC_INTEGER_IMPL(Max)
  GPU_ATOMIC_INTEGER(Max, safe_max(a, b), uint8_t)
  GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
  GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
  GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
  GPU_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)

  inline __device__ at::Half gpuAtomicMax(at::Half * address, at::Half val) {
    const auto take_max = [](at::Half current, at::Half candidate) {
      return safe_max(current, candidate);
    };
    return AtomicFPOp<at::Half>()(address, val, take_max);
  }

  inline __device__ at::BFloat16 gpuAtomicMax(at::BFloat16 * address, at::BFloat16 val) {
    const auto take_max = [](at::BFloat16 current, at::BFloat16 candidate) {
      return safe_max(current, candidate);
    };
    return AtomicFPOp<at::BFloat16>()(address, val, take_max);
  }

  inline __device__ double gpuAtomicMax(double * address, double val) {
    // AtomicFPOp<double> passes the current value as raw bits and expects raw bits back.
    const auto take_max = [](double candidate, unsigned long long int current_bits) {
      return __double_as_longlong(safe_max(candidate, __longlong_as_double(current_bits)));
    };
    return AtomicFPOp<double>()(address, val, take_max);
  }

  // There is no AtomicFPOp<float> specialization (float addition goes straight to
  // the built-in atomicAdd), so the CAS loop is written out here. Returns the
  // float stored at *address before the successful CAS.
  inline __device__ float gpuAtomicMax(float * address, float val) {
    unsigned int *bits_ptr = (unsigned int *)address;
    unsigned int seen = *bits_ptr;
    for (;;) {
      const unsigned int expected = seen;
      const float larger = safe_max(val, __int_as_float(expected));
      seen = atomicCAS(bits_ptr, expected, __float_as_int(larger));
      // Integer comparison, so a NaN result cannot make the loop spin forever.
      if (seen == expected) {
        break;
      }
    }
    return __int_as_float(seen);
  }
  // Atomic minimum implementation.
  // NaN-propagating min. If `b` is NaN, `b` is returned. Otherwise std::min<T>(a, b)
  // is returned, which yields `a` when `a` is NaN because b < NaN is false.
  template <typename T>
  __host__ __device__ T safe_min(T a, T b) {
  #if defined(__HIPCC__)
    // TODO: remove this special case for HIP when issue is fixed:
    // https://github.com/ROCm/hip/issues/2209
    if (at::_isnan(a)) {
      return a;
    }
  #endif
    if (at::_isnan(b)) {
      return b;
    }
    return std::min<T>(a, b);
  }
  // Atomic min. Integer overloads are generated by the macros and return nothing.
  // All overloads combine the values with the NaN-propagating safe_min.
  ATOMIC_INTEGER_IMPL(Min)
  GPU_ATOMIC_INTEGER(Min, safe_min(a, b), uint8_t)
  GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
  GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
  GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
  GPU_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)

  inline __device__ at::Half gpuAtomicMin(at::Half * address, at::Half val) {
    const auto take_min = [](at::Half current, at::Half candidate) {
      return safe_min(current, candidate);
    };
    return AtomicFPOp<at::Half>()(address, val, take_min);
  }

  inline __device__ at::BFloat16 gpuAtomicMin(at::BFloat16 * address, at::BFloat16 val) {
    const auto take_min = [](at::BFloat16 current, at::BFloat16 candidate) {
      return safe_min(current, candidate);
    };
    return AtomicFPOp<at::BFloat16>()(address, val, take_min);
  }

  inline __device__ double gpuAtomicMin(double * address, double val) {
    // AtomicFPOp<double> passes the current value as raw bits and expects raw bits back.
    const auto take_min = [](double candidate, unsigned long long int current_bits) {
      return __double_as_longlong(safe_min(candidate, __longlong_as_double(current_bits)));
    };
    return AtomicFPOp<double>()(address, val, take_min);
  }

  // There is no AtomicFPOp<float> specialization (float addition goes straight to
  // the built-in atomicAdd), so the CAS loop is written out here. Returns the
  // float stored at *address before the successful CAS.
  inline __device__ float gpuAtomicMin(float * address, float val) {
    unsigned int *bits_ptr = (unsigned int *)address;
    unsigned int seen = *bits_ptr;
    for (;;) {
      const unsigned int expected = seen;
      const float smaller = safe_min(val, __int_as_float(expected));
      seen = atomicCAS(bits_ptr, expected, __float_as_int(smaller));
      // Integer comparison, so a NaN result cannot make the loop spin forever.
      if (seen == expected) {
        break;
      }
    }
    return __int_as_float(seen);
  }
  458. #else
  459. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  460. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)