cub.cuh 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/cuda/cub.h>
  4. #include <cstddef>
  5. #include <type_traits>
  6. #include <iterator>
  7. #include <limits>
  8. #ifndef USE_ROCM
  9. #include <cuda/std/functional>
  10. #endif
  11. #include <ATen/cuda/cub_definitions.cuh>
  12. #include <ATen/cuda/CUDAContextLight.h>
  13. #if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
  14. #include <cub/cub.cuh>
  15. #else
  16. // include cub in a safe manner, see:
  17. // https://github.com/pytorch/pytorch/pull/55292
  18. #undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
  19. #undef CUB_NS_PREFIX
  20. #undef CUB_NS_QUALIFIER
  21. #define CUB_NS_PREFIX namespace at_cuda_detail {
  22. #define CUB_NS_POSTFIX }
  23. #define CUB_NS_QUALIFIER ::at_cuda_detail::cub
  24. #include <cub/cub.cuh>
  25. #undef CUB_NS_POSTFIX
  26. #undef CUB_NS_PREFIX
  27. #undef CUB_NS_QUALIFIER
  28. #endif
  29. #include <ATen/cuda/Exceptions.h>
  30. #include <c10/cuda/CUDACachingAllocator.h>
  31. #include <c10/cuda/CUDAStream.h>
  // Handle the temporary storage and 'twice' calls for the cub API.
  //
  // CUB device-wide algorithms use a two-call protocol. The first call passes a
  // null temp-storage pointer and only writes the required scratch size into the
  // size argument. The second call passes a buffer of that size and does the
  // actual work. The scratch buffer comes from the CUDA caching allocator; the
  // DataPtr that owns it is destroyed at the end of the do/while scope. Both
  // calls are checked with AT_CUDA_CHECK. Comments cannot live inside the
  // macro body itself, because a `//` would swallow the continuation lines.
  #define CUB_WRAPPER(func, ...) do {                                        \
    size_t cub_temp_bytes = 0;                                               \
    AT_CUDA_CHECK(func(nullptr, cub_temp_bytes, __VA_ARGS__));               \
    auto cub_temp_buffer =                                                   \
        ::c10::cuda::CUDACachingAllocator::get()->allocate(cub_temp_bytes);  \
    AT_CUDA_CHECK(func(cub_temp_buffer.get(), cub_temp_bytes, __VA_ARGS__)); \
  } while (false)
  // Namespace-selection helpers:
  //   NO_ROCM(x)     -> x on CUDA builds, nothing on ROCm builds.
  //   ROCM_HIPCUB(x) -> x on CUDA builds, ::hipcub on ROCm builds.
  // For example, NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub) names
  // at_cuda_detail::cub on CUDA and ::hipcub on ROCm.
  #if !defined(USE_ROCM)
  #define NO_ROCM(x) x
  #define ROCM_HIPCUB(x) x
  #else
  #define NO_ROCM(x)
  #define ROCM_HIPCUB(x) ::hipcub
  #endif

  // Version-independent spellings of CUB's fancy iterators and max functor.
  // When CUB_V3_PLUS() is false, CUB's own TransformInputIterator,
  // CountingInputIterator, ConstantInputIterator and Max are used, through the
  // helpers above. Otherwise the thrust iterators and ::cuda::maximum<> are used.
  // The first argument of ATEN_CUB_TRANSFORM_ITERATOR is the value type. The
  // legacy form forwards it to TransformInputIterator; the thrust form drops it.
  #if !CUB_V3_PLUS()
  #define ATEN_CUB_TRANSFORM_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::TransformInputIterator<__VA_ARGS__>
  #define ATEN_CUB_COUNTING_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::CountingInputIterator<__VA_ARGS__>
  #define ATEN_CUB_CONSTANT_ITERATOR(...) NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::ConstantInputIterator<__VA_ARGS__>
  #define ATEN_CUB_MAXIMUM() NO_ROCM(at_cuda_detail)ROCM_HIPCUB(::cub)::Max()
  #else
  #include <thrust/iterator/transform_iterator.h>
  #include <thrust/iterator/counting_iterator.h>
  #include <thrust/iterator/constant_iterator.h>
  #define ATEN_CUB_TRANSFORM_ITERATOR(ValueType, ...) ::thrust::transform_iterator<__VA_ARGS__>
  #define ATEN_CUB_COUNTING_ITERATOR(...) ::thrust::counting_iterator<__VA_ARGS__>
  #define ATEN_CUB_CONSTANT_ITERATOR(...) ::thrust::constant_iterator<__VA_ARGS__>
  #define ATEN_CUB_MAXIMUM() ::cuda::maximum<>()
  #endif
  #if defined(USE_ROCM)
  // backport https://github.com/NVIDIA/cub/pull/306 for c10::BFloat16
  //
  // Gives hipCUB the numeric limits and traits of c10::BFloat16 so its
  // radix-sort / reduction machinery can treat it as a 16-bit floating-point
  // type backed by `unsigned short`.
  template <>
  struct ROCM_HIPCUB(cub)::FpLimits<c10::BFloat16>
  {
    // The values are built from raw bits with c10's from_bits constructor rather
    // than by reinterpret_cast-ing an `unsigned short` lvalue to BFloat16&. That
    // cast accessed the object through an unrelated type, which violates the
    // strict-aliasing rules.
    static __host__ __device__ __forceinline__ c10::BFloat16 Max() {
      // 0x7F7F: sign 0, exponent 0xFE, mantissa all ones -> largest finite value.
      return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits());
    }
    static __host__ __device__ __forceinline__ c10::BFloat16 Lowest() {
      // 0xFF7F: same magnitude with the sign bit set -> most negative finite value.
      return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits());
    }
  };
  template <>
  struct ROCM_HIPCUB(cub)::NumericTraits<c10::BFloat16>:
         ROCM_HIPCUB(cub)::BaseTraits<ROCM_HIPCUB(cub)::FLOATING_POINT, true, false, unsigned short, c10::BFloat16> {};
  #endif
  #ifndef USE_ROCM
  // CUDA builds: expose the namespace-wrapped CUB as at::native::cub.
  namespace at::native {
  namespace cub = ::at_cuda_detail::cub;
  } // namespace at::native
  #endif
  84. namespace at::cuda::cub {
  namespace detail {

  // Maps a scalar type to the type handed to CUB / hipCUB.
  // By default the type is used as is. c10::Half maps to __half. c10::BFloat16
  // maps to __nv_bfloat16 on CUDA builds and to hip_bfloat16 on ROCm builds.
  template<typename T>
  struct cuda_type {
    using type = T;
  };

  template<>
  struct cuda_type<c10::Half> {
    using type = __half;
  };

  template<>
  struct cuda_type<c10::BFloat16> {
  #if defined(USE_ROCM)
    using type = hip_bfloat16;
  #else
    using type = __nv_bfloat16;
  #endif
  };

  } // namespace detail
  // Sorts (key, value) pairs within each segment using
  // cub::DeviceSegmentedRadixSort on the current CUDA stream.
  //
  //   keys_in / values_in   : num_elements input keys and values.
  //   keys_out              : sorted keys. May be nullptr; a scratch buffer from
  //                           the CUDA caching allocator then receives the keys,
  //                           and its owner is destroyed when the function returns.
  //   values_out            : values, reordered together with their keys.
  //   begin_offsets / end_offsets : per-segment start / end offsets for the
  //                           num_segments segments.
  //   descending            : use SortPairsDescending instead of SortPairs.
  //   begin_bit / end_bit   : bit range of the key that participates in sorting.
  // Throws (TORCH_CHECK) when num_elements or num_segments exceeds INT_MAX.
  template<typename key_t, typename value_t, typename OffsetIteratorT>
  inline void segmented_sort_pairs(
      const key_t *keys_in, key_t *keys_out,
      const value_t *values_in, value_t *values_out,
      int64_t num_elements, int64_t num_segments,
      OffsetIteratorT begin_offsets, OffsetIteratorT end_offsets,
      bool descending=false, int64_t begin_bit=0, int64_t end_bit=sizeof(key_t)*8
  ) {
    TORCH_CHECK(num_elements <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX elements");
    // Use a message distinct from the elements check, so the error names the
    // limit that was actually exceeded.
    TORCH_CHECK(num_segments <= std::numeric_limits<int>::max(),
      "cub sort does not support sorting more than INT_MAX segments");
    // Keys reach CUB as detail::cuda_type<key_t>::type, which maps
    // c10::Half / c10::BFloat16 to the native half / bfloat16 types.
    using key_t_ = typename detail::cuda_type<key_t>::type;
    auto allocator = c10::cuda::CUDACachingAllocator::get();
    c10::DataPtr keys_out_owner;
    if (keys_out == nullptr) {
      keys_out_owner = allocator->allocate(num_elements * sizeof(key_t));
      keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
    }
    const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
    key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
    if (descending) {
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairsDescending,
        keys_in_, keys_out_, values_in, values_out,
        num_elements, num_segments, begin_offsets, end_offsets,
        begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
    } else {
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSegmentedRadixSort::SortPairs,
        keys_in_, keys_out_, values_in, values_out,
        num_elements, num_segments, begin_offsets, end_offsets,
        begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
    }
  }
  // Runs cub::DeviceSelect::UniqueByKey over num_input_items (key, value) pairs
  // on the current CUDA stream. Surviving values go to values_out and the
  // selected count to num_selected. The deduplicated keys are written to a
  // scratch buffer whose owner is destroyed on return, because only the values
  // are handed back to the caller.
  template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
  inline void unique_by_key(
      KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
      ValuesOutputIteratorT values_out,
      NumSelectedIteratorT num_selected, int64_t num_input_items)
  {
    // TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
    using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
    c10::DataPtr scratch_keys =
        c10::cuda::CUDACachingAllocator::get()->allocate(num_input_items * sizeof(KeyT));
    KeyT* scratch_keys_ptr = static_cast<KeyT*>(scratch_keys.get());
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
        keys_in, values_in, scratch_keys_ptr, values_out, num_selected, num_input_items,
        c10::cuda::getCurrentCUDAStream());
  }
  namespace impl {

  // Single-thread kernel (launch bounds of 1 thread) that stores
  // scan_op(acc_t(*a), acc_t(*b)) into *out, where acc_t is OutputIteratorT's
  // value type. The chunked inclusive_scan / exclusive_scan in this header use
  // it to compute each chunk's carry-in value on the device.
  // NOTE: out here not the final scan output, but an intermediate of the accumulation type.
  template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
  C10_LAUNCH_BOUNDS_1(1)
  __global__ void transform_vals(InputIteratorT1 a, InputIteratorT2 b, OutputIteratorT out, ScanOpT scan_op){
    using acc_t = typename std::iterator_traits<OutputIteratorT>::value_type;
    *out = scan_op(static_cast<acc_t>(*a), static_cast<acc_t>(*b));
  }

  // Largest number of items handed to a single CUB scan call: 2**30, which
  // equals INT_MAX / 2 + 1. Even though cub is supposed to support tensors with
  // int_max elements, in reality it doesn't, so inputs are split at this size.
  constexpr int max_cub_size = 1 << 30;

  } // namespace impl
  // Inclusive scan of input[0, num_items) into output with scan_op
  // (cub::DeviceScan::InclusiveScan), on the current CUDA stream, with no host
  // synchronization.
  //
  // Even though cub is supposed to support tensors with int_max elements, in
  // reality it doesn't, so CUDA builds split the work into chunks of at most
  // max_cub_size items:
  //   * The first chunk is a plain InclusiveScan.
  //   * For each later chunk starting at `start`:
  //       - impl::transform_vals computes carry = scan_op(output[start - 1],
  //         input[start]) on the device. This is the inclusive result at `start`.
  //       - An ExclusiveScan of the inputs after `start` writes
  //         output[start, start + chunk_len). It is seeded with that carry,
  //         passed as a FutureValue so the device reads it when the scan runs.
  // ROCm builds issue a single hipCUB InclusiveScan over all num_items.
  template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, int max_cub_size=impl::max_cub_size>
  inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
  #if defined(USE_ROCM)
    // ROCm: hand hipCUB the whole range in one call.
    CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::InclusiveScan,
        input, output, scan_op, num_items,
        at::cuda::getCurrentCUDAStream());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  #else
    using input_t = typename std::iterator_traits<InputIteratorT>::value_type;
    const auto cuda_stream = at::cuda::getCurrentCUDAStream();
    int chunk_len = std::min<int64_t>(num_items, max_cub_size);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::InclusiveScan,
        input, output, scan_op, chunk_len, cuda_stream);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    auto* allocator = c10::cuda::CUDACachingAllocator::get();
    for (int64_t start = max_cub_size; start < num_items; start += max_cub_size) {
      // Device-resident carry-in for this chunk.
      c10::DataPtr carry_storage = allocator->allocate(sizeof(input_t));
      auto* carry = reinterpret_cast<input_t*>(carry_storage.get());
      chunk_len = std::min<int64_t>(num_items - start, max_cub_size);
      impl::transform_vals<<<1, 1, 0, cuda_stream>>>(
          output + start - 1, input + start, carry, scan_op);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      // NOTE(review): this scan's input range is input[start + 1, start + chunk_len].
      // For the final chunk that range ends at input[num_items], one past the end.
      // Confirm that CUB never dereferences that element, or that every caller's
      // input tolerates the read.
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
          input + start + 1, output + start, scan_op,
          ::at_cuda_detail::cub::FutureValue<input_t>(carry),
          chunk_len, cuda_stream);
    }
  #endif
  }
  // Running-prefix functor for cub::BlockScan's block-prefix callback overload.
  //
  // Each call returns the total accumulated so far, which is the exclusive
  // prefix for the tile about to be scanned, and then adds block_aggregate to
  // it. A single instance can therefore carry a running total across
  // successive tiles. Callback operator to be entered by the first warp of
  // threads in the block; thread 0's return value seeds the block-wide scan.
  template<typename T>
  struct BlockPrefixCallbackOp
  {
    // Initial seed plus every block aggregate seen so far.
    T running_total;

    __host__ __device__ BlockPrefixCallbackOp(T initial_total) : running_total(initial_total) {}

    __host__ __device__ T operator()(T block_aggregate)
    {
      T prefix_before = running_total;
      running_total += block_aggregate;
      return prefix_before;
    }
  };
  // Second pass of inclusive_deterministic_scan: writes the inclusive prefix sum
  // of d_in[0, nelem) into d_out.
  //
  // Launch layout:
  //   * A 1-D grid of BLOCK_THREADS-thread blocks.
  //   * Block b owns the contiguous range starting at
  //     b * BLOCK_THREADS * ITEMS_PER_THREAD * iters_per_cta and walks it in
  //     iters_per_cta tiles of BLOCK_THREADS * ITEMS_PER_THREAD items.
  //   * agg[0, b) must hold the per-block totals of the preceding blocks
  //     (produced by calc_block_sums); their sum seeds this block's scan.
  // Uses only static __shared__ storage.
  template<int BLOCK_THREADS, int ITEMS_PER_THREAD, typename T>
  __global__ void final_scan_kernel(const T* d_in, T* d_out, T* agg, int64_t nelem, int iters_per_cta) {
    // Items one block handles per iteration, held as int64_t so the block offset
    // is computed in 64-bit arithmetic throughout. The previous expression
    // multiplied the int tile size by the int iters_per_cta before widening,
    // which overflows once a block's chunk exceeds INT_MAX items.
    constexpr int64_t tile_items = static_cast<int64_t>(BLOCK_THREADS) * ITEMS_PER_THREAD;
    const int64_t offset = tile_items * iters_per_cta * static_cast<int64_t>(blockIdx.x);
    int64_t remaining = nelem - offset;
    // Blocks past the end of the input have no work. The condition is uniform
    // across the block, so no thread is left behind at a later barrier.
    if (remaining <= 0) {
      return;
    }
    d_in += offset;
    d_out += offset;

    using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_WARP_TRANSPOSE>;
    // Specialize BlockStore type for our thread block (uses warp-striped loads for coalescing, then transposes in shared
    // memory to a blocked arrangement)
    using BlockStoreT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_STORE_WARP_TRANSPOSE>;
    // Specialize BlockScan type for our thread block
    using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan<T, BLOCK_THREADS, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_SCAN_WARP_SCANS>;
    using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<T, BLOCK_THREADS>;

    // The collectives run one after another, so they share a single union of
    // shared storage. The __syncthreads() calls below separate each reuse.
    __shared__ union TempStorage
    {
      typename BlockLoadT::TempStorage load;
      typename BlockStoreT::TempStorage store;
      typename BlockScanT::TempStorage scan;
      typename BlockReduceT::TempStorage reduce;
    } temp_storage;

    // Seed = sum of the totals of all preceding blocks, agg[0, blockIdx.x).
    // Each thread gathers a strided subset; in case there are fewer threads than
    // previous block aggregates, a thread reads more than one.
    T agg_data = threadIdx.x >= blockIdx.x ? T(0) : agg[threadIdx.x];
    for (unsigned int i = threadIdx.x + blockDim.x; i < blockIdx.x; i += blockDim.x) {
      agg_data += agg[i];
    }
    T aggregate = BlockReduceT(temp_storage.reduce).Sum(agg_data);
    __syncthreads();
    BlockPrefixCallbackOp<T> prefix_op(aggregate);

    // Per-thread tile data
    T data[ITEMS_PER_THREAD];
    for (int i = 0; i < iters_per_cta; i++) {
      const bool full_tile = remaining >= tile_items;
      // Load items into a blocked arrangement
      if (full_tile) {
        BlockLoadT(temp_storage.load).Load(d_in, data);
      } else {
        // Tail tile: zero-fill so out-of-range slots do not affect the sums.
        #pragma unroll
        for (int j = 0; j < ITEMS_PER_THREAD; j++) {
          data[j] = 0;
        }
        // remaining < tile_items here, so it fits in int.
        BlockLoadT(temp_storage.load).Load(d_in, data, static_cast<int>(remaining));
      }
      // Barrier for smem reuse
      __syncthreads();
      // Compute inclusive prefix sum, seeded through prefix_op.
      BlockScanT(temp_storage.scan).InclusiveSum(data, data, prefix_op);
      // Barrier for smem reuse
      __syncthreads();
      // Store items from a blocked arrangement
      if (full_tile) {
        BlockStoreT(temp_storage.store).Store(d_out, data);
      } else {
        BlockStoreT(temp_storage.store).Store(d_out, data, static_cast<int>(remaining));
      }
      d_in += tile_items;
      d_out += tile_items;
      remaining -= tile_items;
      if (remaining <= 0) return;
      __syncthreads();
    }
  }
  // Per-element mapping that calc_block_sums applies before summing.
  //   nonzero == false : the element converted to aggT (a plain sum).
  //   nonzero == true  : 1 if the element differs from T(0), else 0 (a count).
  template <typename T, typename aggT, bool nonzero>
  struct TransformFunctor {
    __device__ aggT operator()(T value) const {
      if constexpr (nonzero) {
        return (value != T(0)) ? 1 : 0;
      } else {
        return value;
      }
    }
  };
  // Per-block partial sums: block b sums TransformFunctor<T, aggT, nonzero>
  // applied to each element of its contiguous range of d_in, and thread 0
  // stores the total in agg[b].
  //
  // Launch layout matches final_scan_kernel: block b owns
  // [b * tile * iters_per_cta, (b + 1) * tile * iters_per_cta), clipped to
  // nelem, where tile = BLOCK_THREADS * ITEMS_PER_THREAD. Blocks entirely past
  // nelem write nothing.
  //
  // When nonzero is false, the block whose range reaches nelem returns without
  // writing agg[b]. final_scan_kernel only reads the totals of preceding blocks,
  // so that slot is not consumed. When nonzero is true, that last total is
  // written too, so the overall nonzero count can be formed.
  template<int BLOCK_THREADS, int ITEMS_PER_THREAD, bool nonzero, typename T, typename aggT>
  __global__ void calc_block_sums(const T * d_in, aggT * agg, int64_t nelem, int iters_per_cta){
    // int64_t tile size so the block offset is computed in 64-bit arithmetic.
    // The former int product tile * iters_per_cta could overflow before widening.
    constexpr int64_t tile_items = static_cast<int64_t>(BLOCK_THREADS) * ITEMS_PER_THREAD;
    const int64_t offset = tile_items * iters_per_cta * static_cast<int64_t>(blockIdx.x);
    int64_t remaining = nelem - offset;
    if (remaining <= 0) {
      return;
    }
    d_in += offset;

    using BlockLoadT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockLoad<aggT, BLOCK_THREADS, ITEMS_PER_THREAD, ROCM_HIPCUB(at_cuda_detail::cub)::BLOCK_LOAD_STRIPED>;
    using BlockReduceT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockReduce<aggT, BLOCK_THREADS>;
    // Shared memory, reused by the load and the reduce (separated by barriers).
    __shared__ union TempStorage
    {
      typename BlockLoadT::TempStorage load;
      typename BlockReduceT::TempStorage reduce;
    } temp_storage;

    aggT data[ITEMS_PER_THREAD];
    aggT agg_val = 0;
    TransformFunctor<T, aggT, nonzero> transform_functor;
    auto iter_in = ATEN_CUB_TRANSFORM_ITERATOR(aggT, TransformFunctor<T, aggT, nonzero>, const T*)(d_in, transform_functor);
    for (int i = 0; i < iters_per_cta; i++) {
      if (remaining >= tile_items) {
        BlockLoadT(temp_storage.load).Load(iter_in, data);
      } else {
        // Tail tile: slots past `remaining` (< tile_items, fits in int) get aggT(0).
        BlockLoadT(temp_storage.load).Load(iter_in, data, static_cast<int>(remaining), aggT(0));
      }
      __syncthreads();
      agg_val += BlockReduceT(temp_storage.reduce).Sum(data);
      iter_in += tile_items;
      remaining -= tile_items;
      if (remaining <= 0) {
        // for nonzeros we need to write out last blocks
        // accumulated value to be able to compute
        // total number of nonzeros
        if (nonzero && threadIdx.x == 0) {
          agg[blockIdx.x] = agg_val;
        }
        return;
      }
      __syncthreads();
    }
    if (threadIdx.x == 0) {
      agg[blockIdx.x] = agg_val;
    }
  }
  // Predicate functor: yields 1 when `a` compares unequal to T(0), otherwise 0.
  template <typename T>
  struct NonZeroOp {
    __host__ __device__ __forceinline__ int operator()(const T& a) const {
      return a != T(0) ? 1 : 0;
    }
  };
  // Threads per block for the deterministic-scan kernels, chosen from the
  // element size in bytes: 128 for >= 16 bytes, 256 for >= 8 bytes, else 512.
  // NOTE(review): presumably this keeps per-block resource use roughly
  // constant as elements grow; the rationale is not stated in the code.
  template<int size>
  constexpr int block_threads(){
    return size >= 16 ? 128
         : size >= 8  ? 256
         :              512;
  }
  // Inclusive prefix sum of input[0, num_items) into output, computed without
  // atomics in two passes on the current CUDA stream:
  //   1. calc_block_sums writes one partial total per block into a scratch
  //      buffer.
  //   2. final_scan_kernel seeds each block's scan with the sum of the
  //      preceding blocks' totals.
  // The grid is capped at the device's SM count, and each block loops over
  // iters_per_cta tiles. Only std::plus<scalar_t> is supported, enforced at
  // compile time.
  template<typename scalar_t, typename ScanOpT>
  inline void inclusive_deterministic_scan(const scalar_t * input, scalar_t * output, ScanOpT scan_op, int64_t num_items) {
    static_assert(std::is_same_v<ScanOpT, std::plus<scalar_t>>,
                  "inclusive_deterministic_scan only supports std::plus");
    // Empty input would give grid_size == 0, and CUDA rejects a launch with a
    // zero-sized grid (invalid configuration). There is nothing to compute anyway.
    if (num_items <= 0) {
      return;
    }
    constexpr int BLOCK_THREADS = block_threads<sizeof(scalar_t)>();
    constexpr int ITEMS_PER_THREAD = 16;
    constexpr int64_t tile_items = static_cast<int64_t>(BLOCK_THREADS) * ITEMS_PER_THREAD;
    int64_t grid_size = (num_items + tile_items - 1) / tile_items;
    const int64_t num_sms = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
    const int iters_per_cta = (grid_size + num_sms - 1) / num_sms;
    grid_size = std::min(num_sms, grid_size);
    const auto cuda_stream = at::cuda::getCurrentCUDAStream();
    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
    // One partial total per block.
    auto agg = allocator.allocate(grid_size * sizeof(scalar_t));
    auto* agg_ptr = static_cast<scalar_t*>(agg.get());
    calc_block_sums<BLOCK_THREADS, ITEMS_PER_THREAD, false>
      <<<grid_size, BLOCK_THREADS, 0, cuda_stream>>>(
        input, agg_ptr, num_items, iters_per_cta);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    final_scan_kernel<BLOCK_THREADS, ITEMS_PER_THREAD>
      <<<grid_size, BLOCK_THREADS, 0, cuda_stream>>>(
        input, output, agg_ptr, num_items, iters_per_cta);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  // Exclusive scan of input[0, num_items) into output with scan_op and
  // init_value (cub::DeviceScan::ExclusiveScan), on the current CUDA stream,
  // with no host synchronization.
  //
  // Even though cub is supposed to support tensors with int_max elements, in
  // reality it doesn't, so CUDA builds split the work into chunks of at most
  // max_cub_size items. For each chunk after the first, starting at `start`:
  //   * impl::transform_vals computes carry = scan_op(output[start - 1],
  //     input[start - 1]) on the device. This is the exclusive prefix at `start`.
  //   * The next CUB call uses that carry as its FutureValue init.
  // ROCm builds issue a single hipCUB ExclusiveScan over all num_items.
  template<typename InputIteratorT, typename OutputIteratorT, typename ScanOpT, typename InitValueT, int max_cub_size=impl::max_cub_size>
  inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT scan_op, InitValueT init_value, int64_t num_items) {
  #if defined(USE_ROCM)
    // ROCm: hand hipCUB the whole range in one call.
    CUB_WRAPPER(NO_ROCM(detail)::hipcub::DeviceScan::ExclusiveScan,
        input, output, scan_op, init_value, num_items,
        at::cuda::getCurrentCUDAStream());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  #else
    const auto cuda_stream = at::cuda::getCurrentCUDAStream();
    int chunk_len = std::min<int64_t>(num_items, max_cub_size);
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
        input, output, scan_op, init_value, chunk_len, cuda_stream);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    auto* allocator = c10::cuda::CUDACachingAllocator::get();
    for (int64_t start = max_cub_size; start < num_items; start += max_cub_size) {
      // Device-resident carry-in for this chunk, stored as InitValueT.
      c10::DataPtr carry_storage = allocator->allocate(sizeof(InitValueT));
      auto* carry = reinterpret_cast<InitValueT*>(carry_storage.get());
      chunk_len = std::min<int64_t>(num_items - start, max_cub_size);
      impl::transform_vals<<<1, 1, 0, cuda_stream>>>(
          output + start - 1, input + start - 1, carry, scan_op);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceScan::ExclusiveScan,
          input + start, output + start, scan_op,
          ::at_cuda_detail::cub::FutureValue<InitValueT>(carry),
          chunk_len, cuda_stream);
    }
  #endif
  }
  // Segmented inclusive sum via cub::DeviceScan::InclusiveSumByKey. Runs of
  // consecutive equal keys (compared with equal_to / Equality) form the
  // segments. Runs on the current CUDA stream; throws if num_items exceeds INT_MAX.
  template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
  inline void inclusive_sum_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items) {
    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
      "cub InclusiveSumByKey does not support more than INT_MAX elements");
  #if defined(USE_ROCM)
    CUB_WRAPPER(cub::DeviceScan::InclusiveSumByKey,
        keys, input, output, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());
  #else
    CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveSumByKey,
        keys, input, output, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream());
  #endif
  }
  // Segmented inclusive scan with scan_op via cub::DeviceScan::InclusiveScanByKey.
  // Runs of consecutive equal keys (compared with equal_to / Equality) form the
  // segments. Runs on the current CUDA stream; throws if num_items exceeds INT_MAX.
  template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT, typename ScanOpT>
  inline void inclusive_scan_by_key(KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, ScanOpT scan_op, int64_t num_items) {
    // The message names this function's CUB entry point (InclusiveScanByKey,
    // not InclusiveSumByKey) so the error points at the right call.
    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
      "cub InclusiveScanByKey does not support more than INT_MAX elements");
  #if !defined(USE_ROCM)
    CUB_WRAPPER(at_cuda_detail::cub::DeviceScan::InclusiveScanByKey,
        keys, input, output, scan_op, num_items, NO_ROCM(::cuda)::std::equal_to<>(), at::cuda::getCurrentCUDAStream());
  #else
    CUB_WRAPPER(cub::DeviceScan::InclusiveScanByKey,
        keys, input, output, scan_op, num_items, hipcub::Equality(), at::cuda::getCurrentCUDAStream());
  #endif
  }
  // Runs cub::DeviceSelect::Unique over input[0, num_items), writing the kept
  // items to output and their count through num_selected_out.
  // Runs on the current CUDA stream; throws if num_items exceeds INT_MAX.
  template <typename InputIteratorT, typename OutputIteratorT, typename NumSelectedIteratorT>
  void unique(InputIteratorT input, OutputIteratorT output,
              NumSelectedIteratorT num_selected_out, int64_t num_items) {
    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
      "cub unique does not support more than INT_MAX elements");
    const auto cuda_stream = at::cuda::getCurrentCUDAStream();
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
        input, output, num_selected_out, num_items, cuda_stream);
  }
  // Runs cub::DeviceRunLengthEncode::Encode over input[0, num_items). output
  // receives each run's value, counts_out each run's length, and length_out
  // the number of runs. Runs on the current CUDA stream; throws if num_items
  // exceeds INT_MAX.
  template <typename InputIteratorT, typename OutputIteratorT, typename CountsOutputIteratorT,
            typename LengthOutputIteratorT>
  void run_length_encode(InputIteratorT input, OutputIteratorT output, CountsOutputIteratorT counts_out,
                         LengthOutputIteratorT length_out, int64_t num_items) {
    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
      "cub run_length_encode does not support more than INT_MAX elements");
    const auto cuda_stream = at::cuda::getCurrentCUDAStream();
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
        input, output, counts_out, length_out, num_items, cuda_stream);
  }
  // Reduces input[0, num_items) with op, starting from init, into output via
  // cub::DeviceReduce::Reduce. Runs on the current CUDA stream; throws if
  // num_items exceeds INT_MAX.
  template <typename InputIteratorT, typename OutputIteratorT, typename ReductionOpT, typename T>
  void reduce(InputIteratorT input, OutputIteratorT output, int64_t num_items, ReductionOpT op, T init) {
    TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
      "cub reduce does not support more than INT_MAX elements");
    const auto cuda_stream = at::cuda::getCurrentCUDAStream();
    CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceReduce::Reduce,
        input, output, num_items, op, init, cuda_stream);
  }
  486. } // namespace at::cuda::cub
  487. #else
  488. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  489. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)