DeviceUtils.cuh 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <cuda.h>
  4. #include <c10/util/complex.h>
  5. #include <c10/util/Half.h>
  /**
   * Returns a lane mask for the calling warp.
   *
   * On non-ROCm builds this is the result of `__activemask()`. On ROCm builds
   * a constant all-ones 32-bit value is returned instead; the ROCm code paths
   * of the warp wrappers in this file do not forward a mask to the underlying
   * intrinsics, so the value is ignored there anyway.
   */
  __device__ __forceinline__ unsigned int ACTIVE_MASK() {
  #if defined(USE_ROCM)
    // Placeholder only: will be ignored anyway.
    return 0xffffffff;
  #else
    return __activemask();
  #endif
  }
  /**
   * Warp-level synchronization wrapper.
   *
   * @param mask  Lane mask forwarded to `__syncwarp` on non-ROCm builds
   *              (defaults to all 32 bits set).
   *
   * On ROCm builds the function body is empty and `mask` is unused.
   */
  __device__ __forceinline__ void WARP_SYNC(unsigned mask = 0xffffffff) {
  #if defined(USE_ROCM)
    // Intentionally empty: nothing is executed on ROCm builds.
  #else
    __syncwarp(mask);
  #endif
  }
  #if defined(USE_ROCM)
  /**
   * Warp ballot wrapper (ROCm build).
   *
   * Forwards `predicate` to `__ballot` and returns its result as an
   * `unsigned long long int`. This variant takes no lane mask.
   */
  __device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
  {
    return __ballot(predicate);
  }
  #else
  /**
   * Warp ballot wrapper (non-ROCm build).
   *
   * @param predicate  Per-lane predicate forwarded to `__ballot_sync`.
   * @param mask       Participating-lane mask forwarded to `__ballot_sync`
   *                   (defaults to all 32 bits set).
   * @return The `unsigned int` result of `__ballot_sync(mask, predicate)`.
   */
  __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
  {
    // This definition is only compiled when USE_ROCM is not defined, so no
    // further USE_ROCM check is needed inside the body.
    return __ballot_sync(mask, predicate);
  }
  #endif
  /**
   * Warp shuffle-xor wrapper.
   *
   * Non-ROCm builds forward to `__shfl_xor_sync(mask, value, laneMask, width)`;
   * ROCm builds forward to `__shfl_xor(value, laneMask, width)` and do not use
   * `mask`.
   *
   * @param value     Value contributed by the calling lane.
   * @param laneMask  Lane mask argument passed to the shuffle intrinsic.
   * @param width     Shuffle width (defaults to `warpSize`).
   * @param mask      Participating-lane mask (non-ROCm only; defaults to all
   *                  32 bits set).
   * @return Whatever the underlying shuffle intrinsic returns, as `T`.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_XOR(
      T value,
      int laneMask,
      int width = warpSize,
      unsigned int mask = 0xffffffff) {
  #if defined(USE_ROCM)
    return __shfl_xor(value, laneMask, width);
  #else
    return __shfl_xor_sync(mask, value, laneMask, width);
  #endif
  }
  /**
   * Warp indexed-shuffle wrapper.
   *
   * Non-ROCm builds forward to `__shfl_sync(mask, value, srcLane, width)`;
   * ROCm builds forward to `__shfl(value, srcLane, width)` and do not use
   * `mask`.
   *
   * @param value    Value contributed by the calling lane.
   * @param srcLane  Source lane argument passed to the shuffle intrinsic.
   * @param width    Shuffle width (defaults to `warpSize`).
   * @param mask     Participating-lane mask (non-ROCm only; defaults to all
   *                 32 bits set).
   * @return Whatever the underlying shuffle intrinsic returns, as `T`.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL(
      T value,
      int srcLane,
      int width = warpSize,
      unsigned int mask = 0xffffffff) {
  #if defined(USE_ROCM)
    return __shfl(value, srcLane, width);
  #else
    return __shfl_sync(mask, value, srcLane, width);
  #endif
  }
  /**
   * Warp shuffle-up wrapper.
   *
   * Non-ROCm builds forward to `__shfl_up_sync(mask, value, delta, width)`;
   * ROCm builds forward to `__shfl_up(value, delta, width)` and do not use
   * `mask`.
   *
   * @param value  Value contributed by the calling lane.
   * @param delta  Lane offset passed to the shuffle intrinsic.
   * @param width  Shuffle width (defaults to `warpSize`).
   * @param mask   Participating-lane mask (non-ROCm only; defaults to all
   *               32 bits set).
   * @return Whatever the underlying shuffle intrinsic returns, as `T`.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_UP(
      T value,
      unsigned int delta,
      int width = warpSize,
      unsigned int mask = 0xffffffff) {
  #if defined(USE_ROCM)
    return __shfl_up(value, delta, width);
  #else
    return __shfl_up_sync(mask, value, delta, width);
  #endif
  }
  /**
   * Warp shuffle-down wrapper (primary template).
   *
   * Non-ROCm builds forward to `__shfl_down_sync(mask, value, delta, width)`;
   * ROCm builds forward to `__shfl_down(value, delta, width)` and do not use
   * `mask`.
   *
   * @param value  Value contributed by the calling lane.
   * @param delta  Lane offset passed to the shuffle intrinsic.
   * @param width  Shuffle width (defaults to `warpSize`).
   * @param mask   Participating-lane mask (non-ROCm only; defaults to all
   *               32 bits set).
   * @return Whatever the underlying shuffle intrinsic returns, as `T`.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_DOWN(
      T value,
      unsigned int delta,
      int width = warpSize,
      unsigned int mask = 0xffffffff) {
  #if defined(USE_ROCM)
    return __shfl_down(value, delta, width);
  #else
    return __shfl_down_sync(mask, value, delta, width);
  #endif
  }
  #if defined(USE_ROCM)
  /**
   * ROCm-only specialization of WARP_SHFL_DOWN for `int64_t`.
   *
   * The 64-bit value is reinterpreted as an `int2`, each 32-bit component is
   * shuffled down separately with `__shfl_down`, and the two shuffled
   * components are reinterpreted back into an `int64_t`.
   *
   * @param value  64-bit value contributed by the calling lane.
   * @param delta  Lane offset passed to both 32-bit shuffles.
   * @param width  Shuffle width; forwarded to both 32-bit shuffles so this
   *               specialization honors it the same way the primary template
   *               does.
   * @param mask   Unused in this specialization.
   * @return The reassembled 64-bit value produced by the two shuffles.
   */
  template<>
  __device__ __forceinline__ int64_t WARP_SHFL_DOWN<int64_t>(int64_t value, unsigned int delta, int width , unsigned int mask)
  {
    // (HIP doesn't support int64_t). Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
    int2 halves = *reinterpret_cast<int2*>(&value);
    // Both components must use the caller's width; omitting it would fall back
    // to the intrinsic's default width regardless of what was requested.
    halves.x = __shfl_down(halves.x, delta, width);
    halves.y = __shfl_down(halves.y, delta, width);
    return *reinterpret_cast<int64_t*>(&halves);
  }
  #endif
  /**
   * Specialization of WARP_SHFL_DOWN for `c10::Half`.
   *
   * Shuffles the `unsigned short` member `value.x` through
   * `WARP_SHFL_DOWN<unsigned short>` and constructs a `c10::Half` from the
   * shuffled value using the `from_bits_t` tag constructor.
   */
  template<>
  __device__ __forceinline__ c10::Half WARP_SHFL_DOWN<c10::Half>(c10::Half value, unsigned int delta, int width, unsigned int mask)
  {
    const unsigned short shuffled_bits =
        WARP_SHFL_DOWN<unsigned short>(value.x, delta, width, mask);
    return c10::Half(shuffled_bits, c10::Half::from_bits_t{});
  }
  /**
   * Overload of WARP_SHFL_DOWN for `c10::complex<T>`.
   *
   * The `real_` and `imag_` members are shuffled down independently and a new
   * `c10::complex<T>` is built from the two shuffled components. Non-ROCm
   * builds use `__shfl_down_sync` with `mask`; ROCm builds use `__shfl_down`
   * and do not use `mask`.
   *
   * @param value  Complex value contributed by the calling lane.
   * @param delta  Lane offset passed to both component shuffles.
   * @param width  Shuffle width (defaults to `warpSize`).
   * @param mask   Participating-lane mask (non-ROCm only; defaults to all
   *               32 bits set).
   */
  template <typename T>
  __device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(USE_ROCM)
    const T shuffled_real = __shfl_down(value.real_, delta, width);
    const T shuffled_imag = __shfl_down(value.imag_, delta, width);
  #else
    const T shuffled_real = __shfl_down_sync(mask, value.real_, delta, width);
    const T shuffled_imag = __shfl_down_sync(mask, value.imag_, delta, width);
  #endif
    return c10::complex<T>(shuffled_real, shuffled_imag);
  }
  /**
   * Loads the value pointed to by `p`.
   *
   * For CC 3.5+, the load is performed using `__ldg` (non-ROCm builds).
   * ROCm builds fall back to a plain pointer dereference.
   *
   * @param p  Pointer to the value to load.
   * @return The loaded value, by copy.
   */
  template <typename T>
  __device__ __forceinline__ T doLdg(const T* p) {
  #if defined(USE_ROCM)
    return *p;
  #else
    return __ldg(p);
  #endif
  }
  111. #else
  112. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  113. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)