DistributionTemplates.h 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/Tensor.h>
  4. #include <ATen/Dispatch.h>
  5. #include <ATen/Dispatch_v2.h>
  6. #include <ATen/Generator.h>
  7. #include <ATen/ExpandUtils.h>
  8. #include <ATen/Tensor.h>
  9. #include <ATen/MemoryOverlap.h>
  10. #include <ATen/NamedTensorUtils.h>
  11. #include <ATen/native/Resize.h>
  12. #include <ATen/native/TensorIterator.h>
  13. #include <cmath>
  14. #include <limits>
  15. #include <optional>
  16. #ifndef AT_PER_OPERATOR_HEADERS
  17. #include <ATen/Functions.h>
  18. #else
  19. #include <ATen/ops/empty_like.h>
  20. #include <ATen/ops/empty.h>
  21. #include <ATen/ops/full.h>
  22. #include <ATen/ops/view_as_real.h>
  23. #endif
  24. namespace at::native::templates {
  25. // ==================================================== Random ========================================================
  // The purpose of `update_from` and `update_to` is to find the closest valid int64_t values that can be used as the
  // actual `from` and `to`. The current implementation of `random_` uses uint64_t arithmetic and casts the result to the
  // target dtype (scalar_t). This cast can round a generated number up to a value greater than or equal to `to`.
  // For instance:
  //
  // auto actual = torch::empty({3, 3}, torch::half);
  // actual.random_(0, 65504);
  //
  // If random's uint64_t arithmetic produces 65503, casting it to torch::half yields 65504, which violates the
  // requirement that every random value be less than `to`. To resolve this, `update_from` and `update_to` move `from`
  // to the right and `to` to the left, to the nearest values for which generated numbers stay within the requested
  // [from, to) after casting to the target dtype. For `to` = 65504, `to` moves left by
  // (1 << (floor(log2(to - 1)) - 11 + 1)) = 32 and becomes 65472, the previous representable torch::half value.
  38. template<typename scalar_t>
  39. int64_t update_from(int64_t from) {
  40. static_assert(
  41. std::is_floating_point_v<scalar_t> ||
  42. std::is_same_v<scalar_t, at::Half> ||
  43. std::is_same_v<scalar_t, at::BFloat16>, "scalar_t must be floating-point type");
  44. const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
  45. if (from_plus_1 < from) {
  46. int64_t from_ = std::abs(from + 1);
  47. int n = 0;
  48. while (from_ >>= 1) ++n;
  49. // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
  50. from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
  51. }
  52. return from;
  53. }
  54. template<typename scalar_t>
  55. int64_t update_to(int64_t to) {
  56. static_assert(
  57. std::is_floating_point_v<scalar_t> ||
  58. std::is_same_v<scalar_t, at::Half> ||
  59. std::is_same_v<scalar_t, at::BFloat16>, "scalar_t must be floating-point type");
  60. const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
  61. if (to_minus_1 >= to) {
  62. int64_t to_ = std::abs(to - 1);
  63. int n = 0;
  64. while (to_ >>= 1) ++n;
  65. // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
  66. to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
  67. }
  68. return to;
  69. }
  // Returns `tensor` from the enclosing function when it has no elements, so no
  // kernel is invoked for an empty tensor.
  // See https://github.com/pytorch/pytorch/issues/103418 for more details.
  //
  // The do/while(0) wrapper makes the macro a single statement, so it is safe in
  // an unbraced if/else; the argument is parenthesized for the member call.
  #define CHECK_EMPTY_AND_RETURN(tensor) \
    do {                                 \
      if ((tensor).numel() == 0) {       \
        return tensor;                   \
      }                                  \
    } while (0)
  76. template<template<typename> class random_kernel, typename RNG>
  77. at::Tensor& random_impl(at::Tensor& self, std::optional<Generator> generator) {
  78. CHECK_EMPTY_AND_RETURN(self);
  79. auto iter = at::TensorIterator::borrowing_nullary_op(self);
  80. random_kernel<RNG>()(iter, generator);
  81. return self;
  82. }
  // Fails via TORCH_CHECK unless `var` lies in the closed interval [min, max].
  // Arguments are parenthesized and the body is a single statement; the last line
  // deliberately has no trailing backslash, so the next directive is not swallowed.
  #define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype)                                   \
    do {                                                                                    \
      TORCH_CHECK((var) >= (min) && (var) <= (max), name, " is out of bounds for ", dtype); \
    } while (0)
  // Warns via TORCH_WARN when `var` lies outside [-(2^digits), 2^digits], the range
  // the warning text describes as supported for a discrete uniform distribution.
  #define WARN_OUT_OF_BOUNDS(var, name, digits, dtype)                                 \
    do {                                                                               \
      if ((var) < -(1LL << (digits)) || (var) > (1LL << (digits))) {                   \
        TORCH_WARN(name, " is out of bounds [-(2^", digits, "), 2^", digits, "]. ",     \
          "Due to precision limitations ", dtype, " can support discrete uniform distribution only within this range. ", \
          "This warning will become an error in version 1.7 release, please fix the code in advance"); \
      }                                                                                \
    } while (0)
  91. inline void check_from_to_in_range(int64_t from, int64_t to_inc, caffe2::TypeMeta dtype) {
  92. const auto scalar_type = typeMetaToScalarType(dtype);
  93. if (isFloatingType(scalar_type)) {
  94. AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "check_random_fp_bounds", [&] {
  95. const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
  96. const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
  97. CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
  98. CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
  99. constexpr auto digits = std::numeric_limits<scalar_t>::digits;
  100. WARN_OUT_OF_BOUNDS(from, "from", digits, dtype);
  101. WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, dtype);
  102. });
  103. } else if (scalar_type == kUInt64) {
  104. // When you do a comparison between int64_t and uint64_t, the usual
  105. // arithmetic conversions say that the int64_t value is promoted to
  106. // unsigned. But this conversion wraps around: if I had -1 as my int64_t,
  107. // then it will promote to 0xFFFFFFFFFFFFFFFF in uint64_t. This is never
  108. // the right thing to do.
  109. CHECK_OUT_OF_BOUNDS(from, "from", 0, INT64_MAX, dtype);
  110. CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", 0, INT64_MAX, dtype);
  111. } else if (isIntegralType(scalar_type, /*includeBool=*/true)) {
  112. AT_DISPATCH_V2(scalar_type, "check_random_integral_bounds", AT_WRAP([&]() {
  113. const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
  114. const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
  115. CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
  116. CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, dtype);
  117. }), AT_EXPAND(AT_INTEGRAL_TYPES), kUInt16, kUInt32, kBool);
  118. } else {
  119. TORCH_CHECK(false, "check_random_bounds handles only integral, floating-point and boolean types");
  120. }
  121. }
  // Fills `self` in place with integers drawn by `random_from_to_kernel<RNG>`.
  // The sampled interval depends on the arguments:
  //   - `to_opt` set:                      [from, to)
  //   - `to_opt` unset, from != INT64_MIN: [from, to_inc], where to_inc is
  //                                        2^digits for floating dtypes, the
  //                                        type's max() for integral dtypes and
  //                                        1 for bool
  //   - `to_opt` unset, from == INT64_MIN: the full 64-bit range (2^64 values)
  // For floating dtypes the bounds are first adjusted with update_from/update_to.
  // All validation runs before the empty-tensor early return, so invalid bounds
  // are reported even for empty tensors.
  template<template<typename> class random_from_to_kernel, typename RNG>
  at::Tensor& random_from_to_impl(at::Tensor& self, int64_t from, std::optional<int64_t> to_opt, std::optional<Generator> generator) {
    uint64_t range = 0;
    auto iter = at::TensorIterator::borrowing_nullary_op(self);
    if (to_opt.has_value()) {
      // [from, to)
      int64_t to = *to_opt;
      TORCH_CHECK(from < to, "random_ expects 'from' to be less than 'to', but got from=", from, " >= to=", to);
      if (isFloatingType(iter.dtype())) {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_update_from_to", [&] {
          // Shrink the interval so values cast to scalar_t stay inside the requested [from, to).
          from = update_from<scalar_t>(from);
          to = update_to<scalar_t>(to);
          TORCH_CHECK(from < to, "random_ expects 'from' casted to dtype to be less than 'to' casted to dtype, but got from=", from, " >= to=", to);
        });
      }
      check_from_to_in_range(from, to - 1, self.dtype());
      CHECK_EMPTY_AND_RETURN(self);
      // Subtracting in uint64_t: to - from can exceed INT64_MAX (e.g. from < 0 < to).
      range = static_cast<uint64_t>(to) - static_cast<uint64_t>(from);
      random_from_to_kernel<RNG>()(iter, range, from, generator);
    } else if (from != std::numeric_limits<int64_t>::lowest()) {
      // [from, to_inc], with the inclusive upper bound to_inc chosen per dtype below
      // (at most std::numeric_limits<int64_t>::max()).
      int64_t to_inc = 0;
      if (isFloatingType(iter.dtype())) {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "random_from_to_range_calc", [&] {
          // 2^digits of scalar_t.
          constexpr int64_t scalar_t_max = static_cast<int64_t>(1) << std::numeric_limits<scalar_t>::digits;
          // NOTE(review): an int64_t can never exceed int64_t's max(), so this condition
          // is always false and to_inc == scalar_t_max.
          to_inc = scalar_t_max > std::numeric_limits<int64_t>::max() ? std::numeric_limits<int64_t>::max() : static_cast<int64_t>(scalar_t_max);
          from = update_from<scalar_t>(from);
          // NOTE(review): the check is strict (<) although the message says "less than or equal".
          TORCH_CHECK(from < to_inc, "random_ expects 'from' casted to dtype to be less than or equal to 'to_inc' casted to dtype, but got from=", from, " > to_inc=", to_inc);
        });
      } else if (isIntegralType(iter.dtype(), /*includeBool=*/true)) {
        AT_DISPATCH_V2(self.scalar_type(), "random_from_to_range_calc", AT_WRAP([&] {
          if constexpr (std::is_same_v<scalar_t, bool>) {
            to_inc = static_cast<int64_t>(true);
          } else {
            // NOTE(review): if AT_INTEGRAL_TYPES_V2 includes uint64_t (to confirm), its
            // max() does not fit in int64_t and converts to -1 here, which the kUInt64
            // branch of check_from_to_in_range then rejects.
            to_inc = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
          }
        }), AT_EXPAND(AT_INTEGRAL_TYPES_V2), kBool);
      } else {
        TORCH_CHECK(false, "random_from_to_impl handles only integral, floating-point and boolean types");
      }
      check_from_to_in_range(from, to_inc, self.dtype());
      CHECK_EMPTY_AND_RETURN(self);
      // +1 because to_inc is an inclusive bound.
      range = static_cast<uint64_t>(to_inc) - static_cast<uint64_t>(from) + 1;
      random_from_to_kernel<RNG>()(iter, range, from, generator);
    } else {
      // [std::numeric_limits<int64_t>::lowest(), std::numeric_limits<int64_t>::max()]
      // range = 2^64
      // 2^64 does not fit in uint64_t, so the kernel overload without a range is used.
      CHECK_EMPTY_AND_RETURN(self);
      random_from_to_kernel<RNG>()(iter, generator);
    }
    return self;
  }
  // ==================================================== Normal ========================================================
  // Rejects a complex `std` tensor, and, unless `std` is empty or a meta tensor,
  // one whose minimum element is negative (the minimum is read back via item<bool>()).
  // Both macros expand to a single statement and parenthesize their argument.
  #define CHECK_NORMAL_TENSOR_STD(std)                                             \
    do {                                                                           \
      TORCH_CHECK(                                                                 \
        !(std).is_complex(),                                                       \
        "normal expects standard deviation to be non-complex");                    \
      TORCH_CHECK(                                                                 \
        (std).numel() == 0 || (std).is_meta() || (std).min().ge(0).item<bool>(),   \
        "normal expects all elements of std >= 0.0");                              \
    } while (0)
  // Rejects a negative scalar `std` (NaN is rejected too, since NaN >= 0.0 is false).
  #define CHECK_NORMAL_STD(std)                                                    \
    do {                                                                           \
      TORCH_CHECK((std) >= 0.0, "normal expects std >= 0.0, but found std ", std); \
    } while (0)
  186. template<template<typename> class normal_kernel, typename RNG>
  187. Tensor& normal_impl_(Tensor& self, double mean, double std, std::optional<Generator> gen) {
  188. CHECK_NORMAL_STD(std);
  189. CHECK_EMPTY_AND_RETURN(self);
  190. if (self.is_complex()) {
  191. auto float_tensor = at::view_as_real(self);
  192. // variance for normal distribution of the real and imaginary values
  193. // is half of the input variance
  194. normal_kernel<RNG>()(float_tensor, mean, std/(std::sqrt(2)), gen);
  195. } else {
  196. normal_kernel<RNG>()(self, mean, std, gen);
  197. }
  198. return self;
  199. }
  200. template<template<typename> class normal_kernel, typename RNG>
  201. Tensor& normal_out_impl(Tensor& output, const Tensor& mean, double std, std::optional<Generator> gen) {
  202. CHECK_NORMAL_STD(std);
  203. auto std_tensor = at::empty_like(output, MemoryFormat::Contiguous);
  204. auto shape = at::infer_size(mean.sizes(), std_tensor.sizes());
  205. at::native::resize_output(output, shape);
  206. normal_impl_<normal_kernel, RNG>(output, 0, std, gen);
  207. output.add_(mean);
  208. return output;
  209. }
  210. template<template<typename> class normal_kernel, typename RNG>
  211. Tensor& normal_out_impl(Tensor& output, double mean, const Tensor& std, std::optional<Generator> gen) {
  212. CHECK_NORMAL_TENSOR_STD(std);
  213. auto mean_tensor = at::full({}, mean, output.options());
  214. auto shape = at::infer_size(mean_tensor.sizes(), std.sizes());
  215. at::native::resize_output(output, shape);
  216. normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
  217. // CUDA NB: addcmul_out copies the tensor to be added into the output.
  218. // The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
  219. // The third argument is not a constant reference and hence the samples in output are overwritten.
  220. // Consequently, the computation performed is mean_tensor + mean_tensor * std instead of mean_tensor + output * std
  221. output.mul_(std).add_(mean_tensor);
  222. return output;
  223. }
  224. template<template<typename> class normal_kernel, typename RNG>
  225. Tensor& normal_out_impl(Tensor& output, const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  226. CHECK_NORMAL_TENSOR_STD(std);
  227. auto shape = at::infer_size(mean.sizes(), std.sizes());
  228. at::native::resize_output(output, shape);
  229. normal_impl_<normal_kernel, RNG>(output, 0, 1, gen);
  230. // CUDA NB: addcmul_out copies the tensor to be added into the output.
  231. // The previous function here was addcmul_out(output, mean, output, std, 1);
  232. // The third argument is not a constant reference and hence the samples in output are overwritten.
  233. // Consequently, the computation performed is mean + mean * std instead of mean + output * std
  234. output.mul_(std).add_(mean);
  235. return output;
  236. }
  237. template<template<typename> class normal_kernel, typename RNG>
  238. Tensor normal_impl(const Tensor& mean, double std, std::optional<Generator> gen) {
  239. CHECK_NORMAL_STD(std);
  240. Tensor ret = at::empty_like(mean, MemoryFormat::Contiguous);
  241. normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
  242. return ret;
  243. }
  244. template<template<typename> class normal_kernel, typename RNG>
  245. Tensor normal_impl(double mean, const Tensor& std, std::optional<Generator> gen) {
  246. CHECK_NORMAL_TENSOR_STD(std);
  247. Tensor ret = at::empty_like(std, MemoryFormat::Contiguous);
  248. normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
  249. return ret;
  250. }
  251. template<template<typename> class normal_kernel, typename RNG>
  252. Tensor normal_impl(const Tensor& mean, const Tensor& std, std::optional<Generator> gen) {
  253. CHECK_NORMAL_TENSOR_STD(std);
  254. auto shape = at::infer_size(mean.sizes(), std.sizes());
  255. Tensor ret = at::empty(shape, mean.options(), MemoryFormat::Contiguous);
  256. normal_out_impl<normal_kernel, RNG>(ret, mean, std, gen);
  257. return ret;
  258. }
  259. // ==================================================== Uniform =======================================================
  260. template<template<typename> class uniform_kernel, typename RNG>
  261. at::Tensor& uniform_impl_(at::Tensor& self, double from, double to, std::optional<Generator> generator) {
  262. if (self.is_complex()) {
  263. CHECK_EMPTY_AND_RETURN(self);
  264. auto float_tensor = at::view_as_real(self);
  265. uniform_impl_<uniform_kernel, RNG>(float_tensor, from, to, generator);
  266. } else {
  267. AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "check_uniform_bounds", [&] {
  268. [[maybe_unused]] const auto dtype = self.dtype();
  269. const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
  270. const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
  271. CHECK_OUT_OF_BOUNDS(from, "from", min, max, dtype);
  272. CHECK_OUT_OF_BOUNDS(to, "to", min, max, dtype);
  273. TORCH_CHECK(from <= to, "uniform_ expects to return a [from, to) range, but found from=", from, " > to=", to);
  274. TORCH_CHECK((to - from) <= std::numeric_limits<scalar_t>::max(),
  275. "uniform_ expects to-from <= std::numeric_limits<", toString(self.scalar_type()),
  276. ">::max(), but found to=", to, " and from=", from,
  277. " which result in to-from to exceed the limit");
  278. from = std::min(std::max(from, min), max);
  279. to = std::max(std::min(to, max), min);
  280. });
  281. CHECK_EMPTY_AND_RETURN(self);
  282. auto iter = at::TensorIterator::borrowing_nullary_op(self);
  283. uniform_kernel<RNG>()(iter, from, to, generator);
  284. }
  285. return self;
  286. }
  287. // ================================================== LogNormal =======================================================
  288. template<template<typename> class log_normal_kernel, typename RNG>
  289. at::Tensor& log_normal_impl_(at::Tensor& self, double mean, double std, std::optional<Generator> gen) {
  290. TORCH_CHECK(std > 0.0, "log_normal_ expects std > 0.0, but found std=", std);
  291. CHECK_EMPTY_AND_RETURN(self);
  292. auto iter = TensorIterator::borrowing_nullary_op(self);
  293. log_normal_kernel<RNG>()(iter, mean, std, gen);
  294. return self;
  295. }
  296. // =================================================== Geometric ======================================================
  297. template<template<typename> class geometric_kernel, typename RNG>
  298. Tensor& geometric_impl_(Tensor& self, double p, std::optional<Generator> gen) {
  299. TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
  300. CHECK_EMPTY_AND_RETURN(self);
  301. auto iter = TensorIterator::borrowing_nullary_op(self);
  302. geometric_kernel<RNG>()(iter, p, gen);
  303. return self;
  304. }
  305. // ================================================== Exponential =====================================================
  306. template<template<typename> class exponential_kernel, typename RNG>
  307. Tensor& exponential_impl_(Tensor& self, double lambda, std::optional<Generator> gen) {
  308. TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda);
  309. CHECK_EMPTY_AND_RETURN(self);
  310. auto iter = TensorIterator::borrowing_nullary_op(self);
  311. exponential_kernel<RNG>()(iter, lambda, gen);
  312. return self;
  313. }
  314. // ==================================================== Cauchy ========================================================
  315. template<template<typename> class cauchy_kernel, typename RNG>
  316. Tensor& cauchy_impl_(Tensor& self, double median, double sigma, std::optional<Generator> gen) {
  317. // TODO: instead of variable name 'sigma', use 'gamma' or 'scale'
  318. // the variance, squared sigma, is undefined for cauchy distribution
  319. TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma);
  320. TORCH_CHECK(at::isFloatingType(self.scalar_type()), "Cauchy distribution is a continuous probability distribution. dtype must be a floating point but you specified ", self.dtype());
  321. CHECK_EMPTY_AND_RETURN(self);
  322. auto iter = TensorIterator::borrowing_nullary_op(self);
  323. cauchy_kernel<RNG>()(iter, median, sigma, gen);
  324. return self;
  325. }
  326. // ==================================================== Bernoulli =====================================================
  327. template<template<typename> class bernoulli_tensor_kernel, typename RNG>
  328. Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, std::optional<Generator> gen) {
  329. CHECK_EMPTY_AND_RETURN(self);
  330. NoNamesGuard guard;
  331. at::assert_no_internal_overlap(self);
  332. bernoulli_tensor_kernel<RNG>()(self, p_, gen);
  333. return self;
  334. }
  335. template<template<typename> class bernoulli_scalar_kernel, typename RNG>
  336. Tensor& bernoulli_impl_(Tensor& self, double p, std::optional<Generator> gen) {
  337. TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
  338. CHECK_EMPTY_AND_RETURN(self);
  339. at::assert_no_internal_overlap(self);
  340. bernoulli_scalar_kernel<RNG>()(self, p, gen);
  341. return self;
  342. }
  343. template<template<typename> class bernoulli_tensor_kernel, typename RNG>
  344. Tensor& bernoulli_out_impl(Tensor& result, const Tensor& self, std::optional<Generator> gen) {
  345. // result.resize_as_(self) requires self to have same dtype as result, so we
  346. // use resize_ instead.
  347. // TODO: Fix resize_as_. See pytorch/pytorch#11665.
  348. result.resize_(self.sizes());
  349. bernoulli_impl_<bernoulli_tensor_kernel, RNG>(result, self, gen);
  350. namedinference::propagate_names(result, self);
  351. return result;
  352. }
  353. #undef CHECK_OUT_OF_BOUNDS
  354. #undef WARN_OUT_OF_BOUNDS
  355. } // namespace at::native::templates
  356. #else
  357. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  358. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)