cub.h 3.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  #include <cstdint>
  #include <type_traits>

  #include <ATen/cuda/CUDAConfig.h>
  #include <c10/core/ScalarType.h>
  // NOTE: These templates are intentionally declared but not defined in this
  // header, so they are not recompiled in every translation unit that includes
  // it. If you get a link error, add an explicit instantiation for your types
  // in cub.cu.
  10. namespace at::cuda::cub {
  11. inline int get_num_bits(uint64_t max_key) {
  12. int num_bits = 1;
  13. while (max_key > 1) {
  14. max_key >>= 1;
  15. num_bits++;
  16. }
  17. return num_bits;
  18. }
  19. namespace detail {
  20. // radix_sort_pairs doesn't interact with value_t other than to copy
  21. // the data, so we can save template instantiations by reinterpreting
  22. // it as an opaque type.
  23. // We use native integer types for 1/2/4/8-byte values to reduce
  24. // register usage in CUDA kernels. For sizes > 8 fall back to char array.
  25. template <int N> struct alignas(N) OpaqueType { char data[N]; };
  26. template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
  27. template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
  28. template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
  29. template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
  30. template<typename key_t, int value_size>
  31. void radix_sort_pairs_impl(
  32. const key_t *keys_in, key_t *keys_out,
  33. const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
  34. int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);
  35. } // namespace detail
  36. template<typename key_t, typename value_t>
  37. void radix_sort_pairs(
  38. const key_t *keys_in, key_t *keys_out,
  39. const value_t *values_in, value_t *values_out,
  40. int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8) {
  41. static_assert(std::is_trivially_copyable_v<value_t> ||
  42. AT_ROCM_ENABLED(), // ROCm incorrectly fails this check for vector types
  43. "radix_sort_pairs value type must be trivially copyable");
  44. // Make value type opaque, so all inputs of a certain size use the same template instantiation
  45. using opaque_t = detail::OpaqueType<sizeof(value_t)>;
  46. static_assert(sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
  47. "This size of value_t is not instantiated. Please instantiate it in cub.cu"
  48. " and modify this check.");
  49. static_assert(sizeof(value_t) == alignof(value_t), "Expected value_t to be size-aligned");
  50. detail::radix_sort_pairs_impl(
  51. keys_in, keys_out,
  52. reinterpret_cast<const opaque_t*>(values_in),
  53. reinterpret_cast<opaque_t*>(values_out),
  54. n, descending, begin_bit, end_bit);
  55. }
  56. template<typename key_t>
  57. void radix_sort_keys(
  58. const key_t *keys_in, key_t *keys_out,
  59. int64_t n, bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8);
  60. // NOTE: Intermediate sums will be truncated to input_t precision
  61. template <typename input_t, typename output_t>
  62. void inclusive_sum_truncating(const input_t *input, output_t *output, int64_t n);
  63. template <typename scalar_t>
  64. void inclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  65. return inclusive_sum_truncating(input, output, n);
  66. }
  67. // NOTE: Sums are done is common_type<input_t, output_t>
  68. template <typename input_t, typename output_t>
  69. void exclusive_sum_in_common_type(const input_t *input, output_t *output, int64_t n);
  70. template <typename scalar_t>
  71. void exclusive_sum(const scalar_t *input, scalar_t *output, int64_t n) {
  72. return exclusive_sum_in_common_type(input, output, n);
  73. }
  74. void mask_exclusive_sum(const uint8_t *mask, int64_t *output_idx, int64_t n);
  75. inline void mask_exclusive_sum(const bool *mask, int64_t *output_idx, int64_t n) {
  76. return mask_exclusive_sum(
  77. reinterpret_cast<const uint8_t*>(mask), output_idx, n);
  78. }
  79. } // namespace at::cuda::cub
  80. #else
  81. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  82. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)