ReduceOpsUtils.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <limits>
  4. #include <ATen/core/Tensor.h>
  5. #include <ATen/native/Resize.h>
  6. #include <ATen/native/TensorIterator.h>
  7. #include <ATen/native/NonEmptyUtils.h>
  8. #include <ATen/WrapDimUtilsMulti.h>
  9. #include <c10/core/ScalarType.h>
  10. #include <c10/util/irange.h>
  11. #ifndef AT_PER_OPERATOR_HEADERS
  12. #include <ATen/Functions.h>
  13. #else
  14. #include <ATen/ops/empty.h>
  15. #include <ATen/ops/scalar_tensor.h>
  16. #endif
  17. namespace at::native {
  18. // Maximum and minimum possible scalar values, including infinities
  // Largest value of `scalar_t`: positive infinity when the type has one,
  // otherwise the maximum finite value.
  template <typename scalar_t>
  constexpr scalar_t upper_bound() {
    using limits_t = std::numeric_limits<scalar_t>;
    if constexpr (limits_t::has_infinity) {
      return limits_t::infinity();
    } else {
      return limits_t::max();
    }
  }
  // Smallest value of `scalar_t`: negative infinity when the type has an
  // infinity, otherwise the lowest finite value.
  template <typename scalar_t>
  constexpr scalar_t lower_bound() {
    using limits_t = std::numeric_limits<scalar_t>;
    if constexpr (limits_t::has_infinity) {
      return -limits_t::infinity();
    } else {
      return limits_t::lowest();
    }
  }
  29. inline Tensor restride_dim(
  30. const Tensor& src, int64_t dim,
  31. IntArrayRef replacement_shape
  32. ) {
  33. auto strides = ensure_nonempty_vec(src.strides().vec());
  34. strides[dim] = 0;
  35. return src.as_strided(replacement_shape, strides);
  36. }
  37. inline void _dimreduce_setup(const Tensor &result, const Tensor &self,
  38. int64_t dim) {
  39. IntArrayRef self_sizes = self.sizes();
  40. std::vector<int64_t> result_sizes;
  41. result_sizes.insert(result_sizes.end(), self_sizes.begin(), self_sizes.end());
  42. result_sizes[dim] = 1;
  43. result.resize_(result_sizes);
  44. }
  45. inline bool _dimreduce_return_trivial(const Tensor &result, const Tensor &self,
  46. const Scalar& ident, int64_t dim, bool keepdim) {
  47. if (self.numel() == 1 && self.ndimension() == 0) {
  48. result.resize_({});
  49. result.fill_(self);
  50. return true;
  51. }
  52. // Return identity
  53. if (self.numel() == 0) {
  54. _dimreduce_setup(result, self, dim);
  55. result.fill_(ident);
  56. if (!keepdim) result.squeeze_(dim);
  57. return true;
  58. }
  59. return false;
  60. }
  61. inline bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &self,
  62. int64_t /*dim*/, bool /*keepdim*/, const char* /*fn_name*/) {
  63. if (self.numel() == 1 && self.ndimension() == 0) {
  64. result.resize_({});
  65. result.fill_(self);
  66. return true;
  67. }
  68. return false;
  69. }
  70. inline std::optional<Tensor> _allreduce_return_trivial(
  71. const Tensor& self,
  72. const Scalar& ident) {
  73. // Return identity
  74. if (self.numel() == 0) {
  75. return at::scalar_tensor(ident, self.options());
  76. }
  77. return std::nullopt;
  78. }
  79. #define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \
  80. { \
  81. TORCH_CHECK(\
  82. out.option() == self.option(),\
  83. "expected ", #option, " ",\
  84. self.option(),\
  85. " but found ", out.option())\
  86. }
  87. inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) {
  88. OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self);
  89. OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options());
  90. OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options());
  91. }
  92. inline Tensor integer_upcast(const Tensor& self, std::optional<ScalarType> dtype) {
  93. ScalarType scalarType = self.scalar_type();
  94. TORCH_CHECK(!isBarebonesUnsignedType(scalarType), "integer upcasting for uint16, uint32 and uint64 is not currently implemented");
  95. ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType);
  96. return self.toType(upcast_scalarType);
  97. }
  98. using DimMask = TensorIterator::DimMask;
  99. inline DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
  100. if (opt_dims.has_value()) {
  101. return DimVector(opt_dims.value());
  102. } else {
  103. std::vector<int64_t> all_dims(ndim);
  104. std::iota(all_dims.begin(), all_dims.end(), 0);
  105. return DimVector(all_dims);
  106. }
  107. }
  108. inline DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim, bool allow_empty_dims=false) {
  109. DimMask mask;
  110. if (opt_dims.has_value()) {
  111. auto dims = opt_dims.value();
  112. if (dims.empty() && !allow_empty_dims) {
  113. mask = DimMask().flip();
  114. } else {
  115. mask = at::dim_list_to_bitset(dims, ndim);
  116. }
  117. } else {
  118. mask = DimMask().flip();
  119. }
  120. return mask;
  121. }
  122. inline DimVector shape_from_dim_mask(const Tensor& self, DimMask mask, bool keepdim) {
  123. auto shape = DimVector(self.sizes());
  124. for (int dim = shape.size() - 1; dim >= 0; dim--) {
  125. if (mask[dim]) {
  126. if (keepdim) {
  127. shape[dim] = 1;
  128. } else {
  129. shape.erase(shape.begin() + dim);
  130. }
  131. }
  132. }
  133. return shape;
  134. }
  135. inline void resize_reduction_result(
  136. Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
  137. ScalarType /*dtype*/)
  138. {
  139. auto shape = shape_from_dim_mask(self, mask, keepdim);
  140. TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  141. at::native::resize_output(result, shape);
  142. }
  143. inline Tensor create_reduction_result(
  144. const Tensor& self, at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype
  145. ) {
  146. DimMask mask = make_dim_mask(dim, self.dim());
  147. auto shape = shape_from_dim_mask(self, mask, keepdim);
  148. return at::empty(shape, self.options().dtype(dtype));
  149. }
  150. inline Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
  151. if (keepdim) {
  152. return result;
  153. }
  154. auto shape = DimVector(result.sizes());
  155. auto stride = DimVector(result.strides());
  156. for (const auto dim : c10::irange(ndim)) {
  157. if (mask[dim]) {
  158. shape.insert(shape.begin() + dim, 1);
  159. stride.insert(stride.begin() + dim, 0);
  160. }
  161. }
  162. return result.as_strided(shape, stride);
  163. }
  164. inline TensorIterator make_reduction(
  165. const char* name, Tensor& result, const Tensor& self,
  166. at::OptionalIntArrayRef dim_opt,
  167. bool keepdim, ScalarType in_dtype, ScalarType out_dtype) {
  168. // check that result type and dtype match if provided
  169. TORCH_CHECK(
  170. !result.defined() || result.scalar_type() == out_dtype,
  171. name, ": provided dtype must match dtype of result. Got ",
  172. toString(result.scalar_type()),
  173. " and ",
  174. toString(out_dtype),
  175. ".");
  176. // dim={} performs an all-reduce, same as dim=None
  177. IntArrayRef dim = dim_opt.value_or(IntArrayRef{});
  178. int64_t ndim = self.dim();
  179. auto mask = make_dim_mask(dim, ndim);
  180. resize_reduction_result(result, self, mask, keepdim, out_dtype);
  181. auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
  182. namedinference::propagate_names_for_reduction(result, self, dim, keepdim);
  183. if (self.scalar_type() == in_dtype) {
  184. return TensorIterator::reduce_op(viewed_result, self);
  185. }
  186. return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
  187. }
  188. [[maybe_unused]] inline TensorIterator make_reduction(
  189. const char* name,
  190. Tensor& result,
  191. const Tensor& self,
  192. at::OptionalIntArrayRef dim,
  193. bool keepdim,
  194. ScalarType out_dtype) {
  195. // special case for type promotion in mixed precision, improves computational
  196. // efficiency.
  197. // not generalize this to common mismatched input/output types to avoid cross
  198. // product of templated kernel launches.
  199. const bool gpu_lowp_to_f32 = (
  200. (self.is_cuda() || self.is_xpu()) && (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat);
  201. auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type()
  202. : self.is_complex() ? c10::toComplexType(out_dtype)
  203. : out_dtype;
  204. return make_reduction(name, result, self, dim, keepdim, in_dtype, out_dtype);
  205. }
  206. inline TensorIterator make_reduction(
  207. const char* name, Tensor& result1, Tensor& result2, const Tensor& self,
  208. at::OptionalIntArrayRef dim_opt, bool keepdim, ScalarType dtype1,
  209. ScalarType dtype2) {
  210. // check that result type and dtype match if provided
  211. TORCH_CHECK(
  212. (!result1.defined() || result1.scalar_type() == dtype1) && (!result2.defined() || result2.scalar_type() == dtype2),
  213. name, ": provided dtype must match dtype of result. Got ",
  214. toString(result1.scalar_type()), toString(result2.scalar_type()),
  215. " and ",
  216. toString(dtype1), toString(dtype2),
  217. ".");
  218. // dim={} performs an all-reduce, same as dim=None
  219. auto dim = dim_opt.value_or(IntArrayRef{});
  220. int64_t ndim = self.dim();
  221. DimMask mask = make_dim_mask(dim, ndim);
  222. resize_reduction_result(result1, self, mask, keepdim, dtype1);
  223. auto viewed_result1 = review_reduce_result(result1, ndim, mask, keepdim);
  224. resize_reduction_result(result2, self, mask, keepdim, dtype2);
  225. auto viewed_result2 = review_reduce_result(result2, ndim, mask, keepdim);
  226. namedinference::propagate_names_for_reduction(result1, self, dim, keepdim);
  227. namedinference::propagate_names_for_reduction(result2, self, dim, keepdim);
  228. // special case for type promotion in mixed precision, improves computational
  229. // efficiency.
  230. // We don't generalize this to common mismatched input/output types to avoid cross
  231. // product of templated kernel launches.
  232. if (self.scalar_type() == dtype1 ||
  233. (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
  234. return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  235. }
  236. return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
  237. }
  238. [[maybe_unused]] inline TensorIterator make_reduction(
  239. const char* name,
  240. Tensor& result1,
  241. Tensor& result2,
  242. const Tensor& self,
  243. at::OptionalIntArrayRef dim,
  244. bool keepdim,
  245. ScalarType dtype) {
  246. return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype);
  247. }
  248. inline void zero_numel_check_dims(const Tensor& self, const int64_t dim, const char *fn_name) {
  249. if (self.ndimension() == 0) {
  250. TORCH_CHECK_INDEX(dim == 0 || dim == -1, fn_name,
  251. ": Expected reduction dim -1 or 0 for scalar but got ", dim);
  252. }
  253. else {
  254. TORCH_CHECK_INDEX(self.size(dim) != 0, fn_name,
  255. ": Expected reduction dim ", dim, " to have non-zero size.");
  256. }
  257. }
  258. inline void zero_numel_check_dims(const Tensor& self, const IntArrayRef dim, const char *fn_name) {
  259. TORCH_CHECK(
  260. !dim.empty(),
  261. fn_name, ": Expected reduction dim to be specified for input.numel() == 0. ",
  262. "Specify the reduction dim with the 'dim' argument.");
  263. for (const int64_t d : dim) {
  264. zero_numel_check_dims(self, d, fn_name);
  265. }
  266. }
  267. inline std::vector<int64_t> get_zero_numel_tensor_size(
  268. const Tensor& self,
  269. const int64_t dim,
  270. const bool keepdim,
  271. const char* fn_name) {
  272. TORCH_INTERNAL_ASSERT(self.numel() == 0, fn_name, ": Expected self.numel() == 0.");
  273. zero_numel_check_dims(self, dim, fn_name);
  274. std::vector<int64_t> sizes;
  275. if (keepdim) {
  276. sizes = self.sizes().vec();
  277. sizes[dim] = 1;
  278. }
  279. else {
  280. for (const auto d : c10::irange(self.dim())) {
  281. if (d != dim) {
  282. sizes.push_back(self.sizes()[d]);
  283. }
  284. }
  285. }
  286. return sizes;
  287. }
  288. // Resize the result tensor and indices when result.numel() == 0 depending on values of
  289. // dim and keepdim for returning tensors containing reduction results.
  290. // This function should be called when you are reducing a zero-numel tensor and want to
  291. // resize the output and return it. This function exists for resizing zero-numel
  292. // tensors when the size of the reduction dimension is non-zero.
  293. [[maybe_unused]] inline void zero_numel_tensor_resize(
  294. Tensor& result,
  295. Tensor& result_indices,
  296. const Tensor& self,
  297. const int64_t dim,
  298. const bool keepdim,
  299. const char* fn_name) {
  300. auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name);
  301. at::native::resize_output(result, sizes);
  302. at::native::resize_output(result_indices, sizes);
  303. }
  304. inline ScalarType get_dtype_from_self(
  305. const Tensor& self,
  306. const std::optional<ScalarType>& dtype,
  307. bool promote_integers) {
  308. if (dtype.has_value()) {
  309. return dtype.value();
  310. }
  311. ScalarType src_type = self.scalar_type();
  312. if (promote_integers && at::isIntegralType(src_type, /*includeBool=*/true)) {
  313. return kLong;
  314. }
  315. return src_type;
  316. }
  317. inline ScalarType get_dtype_from_result(Tensor& result, std::optional<ScalarType> dtype) {
  318. TORCH_CHECK(result.defined(), "Cannot create a new tensor inside a reduction op. You likely tried to call an operator with an out argument but the out argument was an undefined tensor.");
  319. if (dtype.has_value()) {
  320. return dtype.value();
  321. } else {
  322. return result.scalar_type();
  323. }
  324. }
  325. } // namespace at::native
  326. namespace at::meta {
  327. [[maybe_unused]] inline DimVector get_reduction_shape(
  328. const Tensor& self,
  329. IntArrayRef dims,
  330. bool keepdim,
  331. bool allow_empty_dims = false) {
  332. auto mask = native::make_dim_mask(dims, self.dim(), allow_empty_dims);
  333. return native::shape_from_dim_mask(self, mask, keepdim);
  334. }
  335. inline void resize_reduction(
  336. impl::MetaBase& meta,
  337. const Tensor& self,
  338. OptionalIntArrayRef opt_dims,
  339. bool keepdim,
  340. ScalarType out_dtype,
  341. bool allow_empty_dims=false) {
  342. DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
  343. maybe_wrap_dims(dims_, self.dim());
  344. auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
  345. if (self.layout() == kStrided) {
  346. meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  347. } else if (shape.empty()) {
  348. meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype).layout(kStrided));
  349. } else {
  350. TORCH_CHECK(false, "resize_reduction: support for output with ", self.layout(), " layout is not implemented yet");
  351. }
  352. namedinference::propagate_names_for_reduction(
  353. meta.maybe_get_output(), self, dims_, keepdim);
  354. }
  355. inline void resize_reduction_with_indices(
  356. impl::MetaBase& meta,
  357. const Tensor& self,
  358. IntArrayRef dims,
  359. bool keepdim,
  360. ScalarType out_dtype) {
  361. DimVector dims_(dims);
  362. maybe_wrap_dims(dims_, self.dim());
  363. auto shape = get_reduction_shape(self, dims_, keepdim);
  364. meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
  365. meta.set_output_raw_strided(1, shape, {}, self.options().dtype(kLong));
  366. namedinference::propagate_names_for_reduction(
  367. meta.maybe_get_output(0), self, dims_, keepdim);
  368. namedinference::propagate_names_for_reduction(
  369. meta.maybe_get_output(1), self, dims_, keepdim);
  370. }
  371. inline TensorIterator make_reduction(
  372. const Tensor& self,
  373. const Tensor& result,
  374. OptionalIntArrayRef opt_dims,
  375. bool keepdim,
  376. ScalarType in_dtype) {
  377. int64_t ndim = self.dim();
  378. auto mask = at::native::make_dim_mask(opt_dims, ndim);
  379. auto viewed_result =
  380. at::native::review_reduce_result(result, ndim, mask, keepdim);
  381. if (self.scalar_type() == in_dtype) {
  382. return TensorIterator::reduce_op(viewed_result, self);
  383. }
  384. return TensorIterator::reduce_op(viewed_result, self.to(in_dtype));
  385. }
  386. inline TensorIterator make_reduction(
  387. const Tensor& self,
  388. const Tensor& result1,
  389. const Tensor& result2,
  390. IntArrayRef dims,
  391. bool keepdim,
  392. ScalarType dtype1,
  393. ScalarType /*dtype2*/) {
  394. int64_t ndim = self.dim();
  395. auto mask = at::native::make_dim_mask(dims, ndim);
  396. auto viewed_result1 = at::native::review_reduce_result(result1, ndim, mask, keepdim);
  397. auto viewed_result2 = at::native::review_reduce_result(result2, ndim, mask, keepdim);
  398. // special case for type promotion in mixed precision, improves computational efficiency.
  399. // We don't generalize this to common mismatched input/output types to avoid cross product
  400. // of templated kernel launches.
  401. if (self.scalar_type() == dtype1 ||
  402. (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
  403. return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
  404. }
  405. return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
  406. }
  407. [[maybe_unused]] inline TensorIterator make_reduction_from_out_ty(
  408. const Tensor& self,
  409. const Tensor& result,
  410. OptionalIntArrayRef opt_dims,
  411. bool keepdim,
  412. ScalarType out_dtype) {
  413. // special case for type promotion in mixed precision, improves computational
  414. // efficiency.
  415. // not generalize this to common mismatched input/output types to avoid cross
  416. // product of templated kernel launches.
  417. const bool gpu_lowp_to_f32 =
  418. (self.is_cuda() &&
  419. (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
  420. out_dtype == kFloat);
  421. auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
  422. return make_reduction(self, result, opt_dims, keepdim, in_dtype);
  423. }
  424. } // namespace at::meta
  425. #else
  426. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  427. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)