ScanUtils.cuh 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/ceil_div.h>
  4. #include <ATen/cuda/DeviceUtils.cuh>
  5. #include <ATen/cuda/AsmUtils.cuh>
  6. #include <c10/macros/Macros.h>
  7. // Collection of in-kernel scan / prefix sum utilities
  namespace at::cuda {

  // Inclusive prefix sum for binary vars using intra-warp voting +
  // shared memory.
  //
  // Must be called by every thread of the block (it contains __syncthreads()).
  // On return, *out holds this thread's inclusive prefix of set `in` flags over
  // all threads with threadIdx.x <= this thread. Within a warp the count comes
  // from a ballot + popcount; per-warp counts are then combined across warps
  // with `binop`. Because the intra-warp step is a popcount, the result is only
  // a meaningful prefix when `binop` is addition with identity 0.
  //
  // smem: scratch of at least ceil_div(blockDim.x, C10_WARP_SIZE) elements. On
  //   return, smem[w] holds the inclusive prefix of set flags through warp w.
  // KillWARDependency: when true, a trailing __syncthreads() keeps any thread
  //   from moving on (and possibly rewriting smem) before all threads have
  //   finished reading smem.
  // NOTE(review): whether WARP_BALLOT tolerates a partially populated final
  //   warp depends on its definition (not visible here). Callers presumably
  //   launch blockDim.x as a multiple of C10_WARP_SIZE; confirm.
  template <typename T, bool KillWARDependency, class BinaryFunction>
  __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
    // Within-warp, we use warp voting. The ballot is kept in an unsigned mask
    // type rather than in T, so it is never converted through T (which would
    // corrupt the mask if T were narrower than the mask or non-integral).
  #if defined (USE_ROCM)
    unsigned long long int vote = WARP_BALLOT(in);
    T index = __popcll(getLaneMaskLe() & vote);
    T carry = __popcll(vote);
  #else
    unsigned int vote = WARP_BALLOT(in);
    T index = __popc(getLaneMaskLe() & vote);
    T carry = __popc(vote);
  #endif

    const int warp = threadIdx.x / C10_WARP_SIZE;
    // Warps in the block, counting a trailing partial warp. This matches the
    // last slot read by exclusiveBinaryPrefixScan, so that slot always ends up
    // holding the block-wide total.
    const int numWarps = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE);

    // Per each warp, write out a value
    if (getLaneId() == 0) {
      smem[warp] = carry;
    }
    __syncthreads();

    // Sum across warps in one thread. This appears to be faster than a
    // warp shuffle scan for CC 3.0+
    if (threadIdx.x == 0) {
      // Accumulate in T (not int) so a wider T is never narrowed.
      T current = 0;
      for (int i = 0; i < numWarps; ++i) {
        T v = smem[i];
        smem[i] = binop(v, current);
        current = binop(current, v);
      }
    }
    __syncthreads();

    // load the carry from the preceding warp
    if (warp >= 1) {
      index = binop(index, smem[warp - 1]);
    }
    *out = index;

    if constexpr (KillWARDependency) {
      __syncthreads();
    }
  }

  // Exclusive prefix sum for binary vars using intra-warp voting +
  // shared memory.
  //
  // Same calling requirements as inclusiveBinaryPrefixScan (all threads of the
  // block, smem of at least ceil_div(blockDim.x, C10_WARP_SIZE) elements).
  // On return, *out holds the inclusive result minus this thread's own `in`
  // flag. *carry, written by every thread, holds the last warp's inclusive
  // prefix, i.e. the number of set flags in the whole block. Turning the
  // inclusive result into an exclusive one by subtracting `in` assumes
  // `binop` is addition.
  template <typename T, bool KillWARDependency, class BinaryFunction>
  __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
    // Skip the inclusive scan's optional trailing barrier: smem is still read
    // below, so the WAR barrier (if requested) is issued at the end instead.
    inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);

    // Inclusive to exclusive
    *out -= static_cast<T>(in);

    // The outgoing carry for all threads is the last warp's inclusive prefix.
    *carry = smem[at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1];

    if constexpr (KillWARDependency) {
      __syncthreads();
    }
  }

  } // namespace at::cuda
  63. #else
  64. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  65. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)