RangeUtils.h 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #include <ATen/AccumulateType.h>
  3. #include <c10/core/Scalar.h>
  4. #include <limits>
  5. namespace at::native {
  6. inline void arange_check_bounds(
  7. const c10::Scalar& start,
  8. const c10::Scalar& end,
  9. const c10::Scalar& step) {
  10. // use double precision for validation to avoid precision issues
  11. double dstart = start.to<double>();
  12. double dend = end.to<double>();
  13. double dstep = step.to<double>();
  14. TORCH_CHECK(dstep > 0 || dstep < 0, "step must be nonzero");
  15. TORCH_CHECK(
  16. std::isfinite(dstart) && std::isfinite(dend),
  17. "unsupported range: ",
  18. dstart,
  19. " -> ",
  20. dend);
  21. TORCH_CHECK(
  22. ((dstep > 0) && (dend >= dstart)) || ((dstep < 0) && (dend <= dstart)),
  23. "upper bound and lower bound inconsistent with step sign");
  24. }
  25. template <typename scalar_t>
  26. int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) {
  27. arange_check_bounds(start, end, step);
  28. // we use double precision for (start - end) / step
  29. // to compute size_d for consistency across devices.
  30. // The problem with using accscalar_t is that accscalar_t might be float32 on gpu for a float32 scalar_t,
  31. // but double on cpu for the same,
  32. // and the effective output size starts differing on CPU vs GPU because of precision issues, which
  33. // we dont want.
  34. // the corner-case we do want to take into account is int64_t, which has higher precision than double
  35. double size_d;
  36. if constexpr (std::is_same_v<scalar_t, int64_t>) {
  37. using accscalar_t = at::acc_type<scalar_t, false>;
  38. auto xstart = start.to<accscalar_t>();
  39. auto xend = end.to<accscalar_t>();
  40. auto xstep = step.to<accscalar_t>();
  41. int64_t sgn = (xstep > 0) - (xstep < 0);
  42. size_d = std::ceil((xend - xstart + xstep - sgn) / xstep);
  43. } else {
  44. size_d = std::ceil((end.to<double>() - start.to<double>())
  45. / step.to<double>());
  46. }
  47. TORCH_CHECK(size_d >= 0 && size_d <= static_cast<double>(std::numeric_limits<int64_t>::max()),
  48. "invalid size, possible overflow?");
  49. return static_cast<int64_t>(size_d);
  50. }
  51. } // namespace at::native
  52. #else
  53. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  54. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)