irange.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright 2004-present Facebook. All Rights Reserved.
  3. #pragma once
  4. #include <c10/util/TypeSafeSignMath.h>
  5. #include <algorithm>
  6. #include <cstddef>
  7. #include <iterator>
  8. #include <type_traits>
  9. namespace c10 {
  10. namespace detail {
  11. template <
  12. typename I,
  13. bool one_sided = false,
  14. std::enable_if_t<std::is_integral_v<I>, int> = 0>
  15. struct integer_iterator {
  16. using iterator_category = std::input_iterator_tag;
  17. using value_type = I;
  18. using difference_type = std::ptrdiff_t;
  19. using pointer = I*;
  20. using reference = I&;
  21. explicit constexpr integer_iterator(I val) : value(val) {}
  22. constexpr I operator*() const {
  23. return value;
  24. }
  25. constexpr I const* operator->() const {
  26. return &value;
  27. }
  28. constexpr integer_iterator& operator++() {
  29. ++value;
  30. return *this;
  31. }
  32. constexpr integer_iterator operator++(int) {
  33. const auto copy = *this;
  34. ++*this;
  35. return copy;
  36. }
  37. constexpr bool operator==(const integer_iterator& other) const {
  38. if constexpr (one_sided) {
  39. // Range-for loops' end test is `begin != end`, not `begin <
  40. // end`. To handle `c10::irange(n)` where n < 0 (which should be
  41. // empty), we just make `begin != end` fail whenever `end` is
  42. // negative.
  43. return is_negative(other.value) || value == other.value;
  44. } else {
  45. return value == other.value;
  46. }
  47. // Suppress "warning: missing return statement at end of non-void function"
  48. // which Nvidia's Robert Crovella confirms is an NVCC compiler error
  49. // here https://stackoverflow.com/a/64561686/752843 on 2020-10-27
  50. // `__builtin_unreachable();` would be best here, but it's not
  51. // available with all compilers. So we instead return an arbitrary
  52. // value trusting that this line will, in fact, never be reached.
  53. return false; // Horrible hack
  54. }
  55. constexpr bool operator!=(const integer_iterator& other) const {
  56. return !(*this == other);
  57. }
  58. protected:
  59. I value;
  60. };
  61. } // namespace detail
  62. template <
  63. typename I,
  64. bool one_sided = false,
  65. std::enable_if_t<std::is_integral_v<I>, bool> = true>
  66. struct integer_range {
  67. public:
  68. constexpr integer_range(I begin, I end) : begin_(begin), end_(end) {}
  69. using iterator = detail::integer_iterator<I, one_sided>;
  70. constexpr iterator begin() const {
  71. return begin_;
  72. }
  73. constexpr iterator end() const {
  74. return end_;
  75. }
  76. private:
  77. iterator begin_;
  78. iterator end_;
  79. };
  80. /// Creates an integer range for the half-open interval [begin, end)
  81. /// If end<=begin, then the range is empty.
  82. /// The range has the type of the `end` integer; `begin` integer is
  83. /// cast to this type.
  84. template <
  85. typename Integer1,
  86. typename Integer2,
  87. std::enable_if_t<std::is_integral_v<Integer1>, bool> = true,
  88. std::enable_if_t<std::is_integral_v<Integer2>, bool> = true>
  89. constexpr integer_range<Integer2> irange(Integer1 begin, Integer2 end) {
  90. // If end<=begin then the range is empty; we can achieve this effect by
  91. // choosing the larger of {begin, end} as the loop terminator
  92. return {
  93. static_cast<Integer2>(begin),
  94. std::max(static_cast<Integer2>(begin), end)};
  95. }
  96. /// Creates an integer range for the half-open interval [0, end)
  97. /// If end<=begin, then the range is empty
  98. template <
  99. typename Integer,
  100. std::enable_if_t<std::is_integral_v<Integer>, bool> = true>
  101. constexpr integer_range<Integer, true> irange(Integer end) {
  102. return {Integer(), end};
  103. }
  104. } // namespace c10
  105. #else
  106. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  107. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)