SharedReduceOps.h 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. // Please note that this file is
  4. // used across both CPU and GPU.
  5. #include <type_traits>
  6. #include <complex>
  7. #include <c10/macros/Macros.h>
  8. #include <ATen/detail/FunctionTraits.h>
  9. #include <ATen/NumericUtils.h>
  10. #include <ATen/OpMathType.h>
  11. #if defined(__CUDACC__)
  12. #include <ATen/cuda/DeviceUtils.cuh>
  13. #include <ATen/native/cuda/DeviceSqrt.cuh>
  14. #elif defined(__HIPCC__)
  15. #include <ATen/hip/DeviceUtils.cuh>
  16. #include <ATen/native/hip/DeviceSqrt.cuh>
  17. #endif
  18. #if defined(__CUDACC__) || defined(__HIPCC__)
  19. #include <thrust/pair.h>
  20. #else
  21. #include <cmath>
  22. #define device_sqrt std::sqrt
  23. #endif
  24. #if defined(__CUDACC__) || defined(__HIPCC__)
  25. template <typename scalar_t>
  26. inline C10_DEVICE scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
  27. #if defined(__HIPCC__)
  28. // TODO: remove this special case for HIP when issue is fixed:
  29. // https://github.com/ROCm/hip/issues/2209
  30. scalar_t max = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::max(a, b));
  31. #else
  32. scalar_t max = at::_isnan(b) ? b : std::max(a, b);
  33. #endif
  34. return max;
  35. }
  36. template <typename scalar_t>
  37. inline C10_DEVICE scalar_t min_propagate_nan(scalar_t a, scalar_t b) {
  38. #if defined(__HIPCC__)
  39. // TODO: remove this special case for HIP when issue is fixed:
  40. // https://github.com/ROCm/hip/issues/2209
  41. scalar_t min = at::_isnan(a) ? a : (at::_isnan(b) ? b : std::min(a, b));
  42. #else
  43. scalar_t min = at::_isnan(b) ? b : std::min(a, b);
  44. #endif
  45. return min;
  46. }
  47. #define MAX(X, Y) max_propagate_nan(X,Y)
  48. #define MIN(X, Y) min_propagate_nan(X,Y)
  49. #else
  50. #include <ATen/native/cpu/zmath.h>
  51. #define MAX(X, Y) max_impl(X,Y)
  52. #define MIN(X, Y) min_impl(X,Y)
  53. #endif
  54. // ROCm hip compiler doesn't work well with using std:: in kernel functions
  55. #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  56. #if defined(__CUDA_ARCH__)
  57. #include <c10/cuda/CUDAMathCompat.h>
  58. #elif defined(__HIPCC__)
  59. #include <c10/hip/HIPMathCompat.h>
  60. #endif
  61. #define compat_pow c10::cuda::compat::pow
  62. #else
  63. #define compat_pow std::pow
  64. #endif
  65. namespace at::native {
  namespace detail {
  // Pair type shared by the reduction ops: thrust::pair when compiling with
  // nvcc/hipcc, std::pair otherwise.
  template <typename T1, typename T2>
  using pair =
  #if defined(__CUDACC__) || defined(__HIPCC__)
      thrust::pair<T1, T2>;
  #else
      std::pair<T1, T2>;
  #endif
  } // namespace detail
  73. template <typename scalar_t, typename index_t>
  74. struct WelfordData {
  75. scalar_t mean;
  76. scalar_t m2;
  77. index_t n;
  78. scalar_t nf;
  79. C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
  80. C10_HOST_DEVICE WelfordData(
  81. scalar_t mean,
  82. scalar_t m2,
  83. index_t n,
  84. scalar_t nf)
  85. : mean(mean), m2(m2), n(n), nf(nf) {}
  86. };
  87. template <typename scalar_t, typename acc_scalar_t, typename index_t, typename res_t>
  88. struct WelfordOps {
  89. acc_scalar_t correction;
  90. bool take_sqrt;
  91. public:
  92. using acc_t = WelfordData<acc_scalar_t, index_t>;
  93. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
  94. // We accumulate n in index_t to avoid cumulative rounding error, but still
  95. // need nf for use in combine where int32 may overflow.
  96. index_t new_n = acc.n + 1;
  97. acc_scalar_t new_nf = static_cast<acc_scalar_t>(new_n);
  98. acc_scalar_t delta = data - acc.mean;
  99. acc_scalar_t new_mean = acc.mean + delta / new_nf;
  100. acc_scalar_t new_delta = data - new_mean;
  101. return {
  102. new_mean,
  103. acc.m2 + delta * new_delta,
  104. new_n,
  105. new_nf,
  106. };
  107. }
  108. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  109. if (a.nf == 0) {
  110. return b;
  111. }
  112. if (b.nf == 0) {
  113. return a;
  114. }
  115. acc_scalar_t delta = b.mean - a.mean;
  116. acc_scalar_t new_count = a.nf + b.nf;
  117. acc_scalar_t nb_over_n = b.nf / new_count;
  118. return {
  119. a.mean + delta * nb_over_n,
  120. a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
  121. // setting acc.n as -1 since acc.n might not be able to represent the count
  122. // correctly within its range, setting it to -1 to avoid confusion
  123. -1,
  124. new_count
  125. };
  126. }
  127. inline C10_DEVICE res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ {
  128. const auto mean = static_cast<scalar_t>(acc.mean);
  129. const auto divisor = acc.nf > correction ? acc.nf - correction : 0;
  130. const auto var = acc.m2 / divisor;
  131. res_t results(take_sqrt ? device_sqrt(var) : var, mean);
  132. return results;
  133. }
  134. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  135. return acc;
  136. }
  137. #if defined(__CUDACC__) || defined(__HIPCC__)
  138. inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
  139. return {
  140. WARP_SHFL_DOWN(acc.mean, offset)
  141. , WARP_SHFL_DOWN(acc.m2, offset)
  142. , WARP_SHFL_DOWN(acc.n, offset)
  143. , WARP_SHFL_DOWN(acc.nf, offset)
  144. };
  145. }
  146. #endif
  147. C10_HOST_DEVICE WelfordOps(acc_scalar_t correction, bool take_sqrt)
  148. : correction(correction), take_sqrt(take_sqrt) {}
  149. };
  150. template <typename scalar_t, typename acc_t=scalar_t, typename factor_t=acc_t, typename out_t = acc_t>
  151. struct MeanOps {
  152. factor_t factor;
  153. inline C10_DEVICE acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const {
  154. return combine(a, static_cast<acc_t>(b));
  155. }
  156. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  157. return a + b;
  158. }
  159. inline C10_DEVICE out_t project(acc_t a) const {
  160. return a * factor;
  161. }
  162. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  163. return acc;
  164. }
  165. #if defined(__CUDACC__) || defined(__HIPCC__)
  166. inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
  167. return WARP_SHFL_DOWN(data, offset);
  168. }
  169. #endif
  170. MeanOps(factor_t factor): factor(factor) {
  171. }
  172. };
  173. // This accumulator template is used to calculate the minimum absolute value of
  174. // a set of numbers.
  175. // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
  176. // value. These types differ for complex number input support.
  177. template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  178. struct AbsMinOps {
  179. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
  180. return MIN(acc, static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data))));
  181. }
  182. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  183. return MIN(a, b);
  184. }
  185. inline C10_DEVICE out_t project(acc_t a) const {
  186. return a;
  187. }
  188. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  189. return acc;
  190. }
  191. #if defined(__CUDACC__) || defined(__HIPCC__)
  192. inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
  193. return WARP_SHFL_DOWN(acc, offset);
  194. }
  195. #endif
  196. };
  197. // This accumulator template is used to calculate the maximum absolute value of
  198. // a set of numbers.
  199. // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
  200. // value. These types differ for complex number input support.
  201. template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  202. struct AbsMaxOps {
  203. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
  204. return MAX(acc, static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data))));
  205. }
  206. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  207. return MAX(a, b);
  208. }
  209. inline C10_DEVICE out_t project(acc_t a) const {
  210. return a;
  211. }
  212. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  213. return acc;
  214. }
  215. #if defined(__CUDACC__) || defined(__HIPCC__)
  216. inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
  217. return WARP_SHFL_DOWN(acc, offset);
  218. }
  219. #endif
  220. };
  221. // This accumulator template is used to calculate the norm of the absolute value
  222. // of a set of numbers.
  223. // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
  224. // value. These types differ for complex number input support.
  225. // `apply_root` controls whether to apply the final root: if true, returns
  226. // (sum(|x|^p))^(1/p); if false, returns sum(|x|^p) (used by linalg._powsum).
  227. template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t, bool apply_root = true>
  228. struct NormOps {
  229. acc_t norm_;
  230. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
  231. return acc + compat_pow(static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data))), norm_);
  232. }
  233. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  234. return a + b;
  235. }
  236. inline C10_DEVICE out_t project(acc_t a) const {
  237. if constexpr (apply_root) {
  238. return compat_pow(a, static_cast<acc_t>(1.0) / norm_);
  239. } else {
  240. return a;
  241. }
  242. }
  243. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  244. return acc;
  245. }
  246. #if defined(__CUDACC__) || defined(__HIPCC__)
  247. inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
  248. return WARP_SHFL_DOWN(acc, offset);
  249. }
  250. #endif
  251. NormOps(acc_t norm_): norm_(norm_) {
  252. }
  253. };
  254. // This accumulator template is used to calculate the order zero norm of the
  255. // absolute value of a set of numbers.
  256. // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
  257. // value. These types differ for complex number input support.
  258. template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  259. struct NormZeroOps {
  260. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
  261. return acc + (data == static_cast<scalar_t>(0) ? static_cast<acc_t>(0) : static_cast<acc_t>(1));
  262. }
  263. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  264. return a + b;
  265. }
  266. inline C10_DEVICE out_t project(acc_t a) const {
  267. return a;
  268. }
  269. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  270. return acc;
  271. }
  272. #if defined(__CUDACC__) || defined(__HIPCC__)
  273. inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
  274. return WARP_SHFL_DOWN(acc, offset);
  275. }
  276. #endif
  277. };
  278. // This accumulator template is used to calculate the order one norm of the
  279. // absolute value of a set of numbers.
  280. // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
  281. // value. These types differ for complex number input support.
  282. template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t>
  283. struct NormOneOps {
  284. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
  285. return acc + static_cast<acc_t>(std::abs(at::opmath_type<scalar_t>(data)));
  286. }
  287. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  288. return a + b;
  289. }
  290. inline C10_DEVICE out_t project(acc_t a) const {
  291. return a;
  292. }
  293. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  294. return acc;
  295. }
  296. #if defined(__CUDACC__) || defined(__HIPCC__)
  297. inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
  298. return WARP_SHFL_DOWN(acc, offset);
  299. }
  300. #endif
  301. };
  302. template<typename acc_t>
  303. struct AbsSwitch {};
  304. template<typename scalar_t, typename acc_t>
  305. inline C10_DEVICE acc_t abs_if_complex(scalar_t data, AbsSwitch<acc_t> /*unused*/) {
  306. return static_cast<acc_t>(data);
  307. }
  308. template<typename scalar_t, typename acc_t>
  309. inline C10_DEVICE acc_t abs_if_complex(std::complex<scalar_t> data, AbsSwitch<acc_t> /*unused*/) {
  310. return static_cast<acc_t>(std::abs(data));
  311. }
  312. template<typename scalar_t, typename acc_t>
  313. inline C10_DEVICE acc_t abs_if_complex(c10::complex<scalar_t> data, AbsSwitch<acc_t> /*unused*/) {
  314. return static_cast<acc_t>(std::abs(at::opmath_type<c10::complex<scalar_t>>(data)));
  315. }
  316. // This accumulator template is used to calculate the order two norm of the
  317. // absolute value of a set of numbers.
  318. // `scalar_t` is the type of the input and `acc_t` is the type of the accumulated
  319. // value. These types differ for complex number input support.
  320. // `apply_root` controls whether to apply the final sqrt: if true, returns
  321. // sqrt(sum(|x|^2)); if false, returns sum(|x|^2) (used by linalg._powsum).
  322. template <typename scalar_t, typename acc_t = scalar_t, typename out_t = acc_t, bool apply_root = true>
  323. struct NormTwoOps {
  324. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const {
  325. acc_t data_ = abs_if_complex(data, AbsSwitch<acc_t>());
  326. return acc + data_ * data_;
  327. }
  328. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  329. return a + b;
  330. }
  331. inline C10_DEVICE out_t project(acc_t a) const {
  332. if constexpr (apply_root) {
  333. return device_sqrt(a);
  334. } else {
  335. return a;
  336. }
  337. }
  338. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  339. return acc;
  340. }
  341. #if defined(__CUDACC__) || defined(__HIPCC__)
  342. inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
  343. return WARP_SHFL_DOWN(acc, offset);
  344. }
  345. #endif
  346. };
  347. template <typename acc_t, typename data_t>
  348. struct NanSumOps {
  349. inline C10_DEVICE acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const {
  350. return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b});
  351. }
  352. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  353. return a + b;
  354. }
  355. inline C10_DEVICE data_t project(acc_t a) const {
  356. return data_t{a};
  357. }
  358. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  359. return acc;
  360. }
  361. #if defined(__CUDACC__) || defined(__HIPCC__)
  362. inline C10_DEVICE acc_t warp_shfl_down(acc_t data, int offset) const {
  363. return WARP_SHFL_DOWN(data, offset);
  364. }
  365. #endif
  366. };
  367. namespace detail {
  368. template <typename scalar_t>
  369. struct LessOrNan {
  370. C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
  371. // If (a == b), then choose the one with lower idx, else min(a, b)
  372. if (at::_isnan(a)) {
  373. if (at::_isnan(b)) {
  374. return idx_a < idx_b;
  375. }
  376. return true;
  377. }
  378. return (a == b) ? idx_a < idx_b : (a < b);
  379. }
  380. };
  381. template <typename scalar_t>
  382. struct GreaterOrNan {
  383. C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const {
  384. // If (a == b), then choose the one with lower idx, else max(a, b)
  385. if (at::_isnan(a)) {
  386. if (at::_isnan(b)) {
  387. return idx_a < idx_b;
  388. }
  389. return true;
  390. }
  391. return (a == b) ? idx_a < idx_b : (a > b);
  392. }
  393. };
  394. template <typename comp_t>
  395. struct MinMaxReductionOps {
  396. using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  397. using index_t = int64_t;
  398. using arg_t = detail::pair<scalar_t, index_t>;
  399. static C10_DEVICE arg_t project(arg_t arg) {
  400. return arg;
  401. }
  402. static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) {
  403. return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx);
  404. }
  405. static C10_DEVICE arg_t combine(arg_t a, arg_t b) {
  406. return comp_t{}(a.first, b.first, a.second, b.second) ? a : b;
  407. }
  408. static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) {
  409. return {a.first, a.second + base_idx};
  410. }
  411. #if defined(__CUDACC__) || defined(__HIPCC__)
  412. static C10_DEVICE arg_t warp_shfl_down(arg_t arg, int offset) {
  413. return arg_t(WARP_SHFL_DOWN(arg.first, offset),
  414. WARP_SHFL_DOWN(arg.second, offset));
  415. }
  416. #endif
  417. };
  418. template <typename comp_t>
  419. struct ArgReductionOps : public MinMaxReductionOps<comp_t> {
  420. using typename MinMaxReductionOps<comp_t>::scalar_t;
  421. using typename MinMaxReductionOps<comp_t>::index_t;
  422. using typename MinMaxReductionOps<comp_t>::arg_t;
  423. static C10_DEVICE index_t project(arg_t arg) {
  424. return arg.second;
  425. }
  426. };
  427. } // namespace detail
  428. template <typename scalar_t>
  429. struct ArgMaxOps :
  430. public detail::ArgReductionOps<detail::GreaterOrNan<scalar_t>> {
  431. };
  432. template <typename scalar_t>
  433. struct ArgMinOps :
  434. public detail::ArgReductionOps<detail::LessOrNan<scalar_t>> {
  435. };
  436. template <typename scalar_t>
  437. struct MinOps :
  438. public detail::MinMaxReductionOps<detail::LessOrNan<scalar_t>> {
  439. };
  440. template <typename scalar_t>
  441. struct MaxOps :
  442. public detail::MinMaxReductionOps<detail::GreaterOrNan<scalar_t>> {
  443. };
  444. template <typename scalar_t, typename acc_scalar_t, typename index_t>
  445. struct MinMaxOps {
  446. using acc_t = detail::pair<acc_scalar_t, acc_scalar_t>;
  447. inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const {
  448. return combine(acc, {data, data});
  449. }
  450. inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
  451. auto min_val = (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first;
  452. auto max_val = (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second;
  453. return {min_val, max_val};
  454. }
  455. inline C10_DEVICE acc_t project(acc_t acc) const {
  456. return acc;
  457. }
  458. static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) {
  459. return acc;
  460. }
  461. #if defined(__CUDACC__) || defined(__HIPCC__)
  462. inline C10_DEVICE acc_t warp_shfl_down(acc_t acc, int offset) const {
  463. return {
  464. WARP_SHFL_DOWN(acc.first, offset), WARP_SHFL_DOWN(acc.second, offset)
  465. };
  466. }
  467. #endif
  468. };
  469. } // namespace at::native
  470. #undef MAX
  471. #undef MIN
  472. #else
  473. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  474. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)