Enumerate.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. /*
  3. * Ported from folly/container/Enumerate.h
  4. */
  5. #pragma once
#include <cstddef>
#include <iterator>
#include <memory>
#include <utility>
  8. #ifdef _WIN32
  9. #include <basetsd.h> // @manual
  10. using ssize_t = SSIZE_T;
  11. #endif
  12. #include <c10/macros/Macros.h>
  13. /**
 * Similar to Python's enumerate(), enumerate() can be used to
 * iterate over a range with a range-based for loop while also
 * retrieving the number of iterations so far. It can be used in a
 * constexpr context.
  18. *
  19. * For example:
  20. *
  21. * for (auto&& [index, element] : enumerate(vec)) {
  22. * // index is a const reference to a size_t containing the iteration count.
  23. * // element is a reference to the type contained within vec, mutable
  24. * // unless vec is const.
  25. * }
  26. *
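 * Note that making the binding const does NOT make element const: element
 * names a reference member of the proxy, and const does not propagate
 * through references. Enumerate a const range (e.g. std::as_const(vec)) to
 * get read-only elements.
 *
 * for (const auto&& [index, element] : enumerate(vec)) {
 * // index is const; element is still mutable unless vec is const.
 * }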
  32. *
  33. * It can also be used as follows:
  34. *
  35. * for (auto&& it : enumerate(vec)) {
  36. * // *it is a reference to the current element. Mutable unless vec is const.
  37. * // it->member can be used as well.
  38. * // it.index contains the iteration count.
  39. * }
  40. *
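 * const auto&& it can also be used: through a const proxy, *it and it->
 * always give const access to the element, even when vec is mutable
 * (it.element itself is a plain reference member and stays mutable).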
  42. */
  43. namespace c10 {
  44. namespace detail {
  45. template <class T>
  46. struct MakeConst {
  47. using type = const T;
  48. };
  49. template <class T>
  50. struct MakeConst<T&> {
  51. using type = const T&;
  52. };
  53. template <class T>
  54. struct MakeConst<T*> {
  55. using type = const T*;
  56. };
  57. template <class Iterator>
  58. class Enumerator {
  59. public:
  60. constexpr explicit Enumerator(Iterator it) : it_(std::move(it)) {}
  61. class Proxy {
  62. public:
  63. using difference_type = ssize_t;
  64. using value_type = typename std::iterator_traits<Iterator>::value_type;
  65. using reference = typename std::iterator_traits<Iterator>::reference;
  66. using pointer = typename std::iterator_traits<Iterator>::pointer;
  67. using iterator_category = std::input_iterator_tag;
  68. C10_ALWAYS_INLINE constexpr explicit Proxy(const Enumerator& e)
  69. : index(e.idx_), element(*e.it_) {}
  70. // Non-const Proxy: Forward constness from Iterator.
  71. C10_ALWAYS_INLINE constexpr reference operator*() {
  72. return element;
  73. }
  74. C10_ALWAYS_INLINE constexpr pointer operator->() {
  75. return std::addressof(element);
  76. }
  77. // Const Proxy: Force const references.
  78. C10_ALWAYS_INLINE constexpr typename MakeConst<reference>::type operator*()
  79. const {
  80. return element;
  81. }
  82. C10_ALWAYS_INLINE constexpr typename MakeConst<pointer>::type operator->()
  83. const {
  84. return std::addressof(element);
  85. }
  86. public:
  87. size_t index;
  88. reference element;
  89. };
  90. C10_ALWAYS_INLINE constexpr Proxy operator*() const {
  91. return Proxy(*this);
  92. }
  93. C10_ALWAYS_INLINE constexpr Enumerator& operator++() {
  94. ++it_;
  95. ++idx_;
  96. return *this;
  97. }
  98. template <typename OtherIterator>
  99. C10_ALWAYS_INLINE constexpr bool operator==(
  100. const Enumerator<OtherIterator>& rhs) const {
  101. return it_ == rhs.it_;
  102. }
  103. template <typename OtherIterator>
  104. C10_ALWAYS_INLINE constexpr bool operator!=(
  105. const Enumerator<OtherIterator>& rhs) const {
  106. return !(it_ == rhs.it_);
  107. }
  108. private:
  109. template <typename OtherIterator>
  110. friend class Enumerator;
  111. Iterator it_;
  112. size_t idx_ = 0;
  113. };
  114. template <class Range>
  115. class RangeEnumerator {
  116. Range r_;
  117. using BeginIteratorType = decltype(std::declval<Range>().begin());
  118. using EndIteratorType = decltype(std::declval<Range>().end());
  119. public:
  120. // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  121. constexpr explicit RangeEnumerator(Range&& r) : r_(std::forward<Range>(r)) {}
  122. constexpr Enumerator<BeginIteratorType> begin() {
  123. return Enumerator<BeginIteratorType>(r_.begin());
  124. }
  125. constexpr Enumerator<EndIteratorType> end() {
  126. return Enumerator<EndIteratorType>(r_.end());
  127. }
  128. };
  129. } // namespace detail
  130. template <class Range>
  131. constexpr detail::RangeEnumerator<Range> enumerate(Range&& r) {
  132. return detail::RangeEnumerator<Range>(std::forward<Range>(r));
  133. }
  134. } // namespace c10
  135. #else
  136. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  137. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)