Bitset.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  #include <cstddef>
  #include <cstdint>
  #if defined(_MSC_VER)
  #include <intrin.h>
  #endif
  7. namespace c10::utils {
  /**
   * A simple bitset with 8 * sizeof(long long int) bits (the bit width of a
   * long long). You can set bits, unset bits, query bits by index, and visit
   * every set bit with for_each_set_bit().
   *
   * Before using this class, please also take a look at std::bitset,
   * which has more functionality and is more generic. It is probably
   * a better fit for your use case. The sole reason for c10::utils::bitset
   * to exist is that std::bitset lacks a find-first-set operation, which
   * for_each_set_bit() relies on to visit only the bits that are set.
   */
  struct bitset final {
   private:
  #if defined(_MSC_VER)
    // 64-bit storage scanned with MSVC's _BitScanForward64 / _BitScanForward
    // intrinsics in find_first_set().
    using bitset_type = int64_t;
  #else
    // POSIX ffsll (used here via __builtin_ffsll) takes a long long int.
    using bitset_type = long long int;
  #endif

   public:
    // Number of bits this set can hold.
    static constexpr size_t NUM_BITS() {
      return 8 * sizeof(bitset_type);
    }

    constexpr bitset() noexcept = default;
    constexpr bitset(const bitset&) noexcept = default;
    constexpr bitset(bitset&&) noexcept = default;
    // Deliberately not constexpr: gcc 5.3.0 has an issue with defaulted
    // functions declared constexpr, see
    // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
    bitset& operator=(const bitset&) noexcept = default;
    bitset& operator=(bitset&&) noexcept = default;
    ~bitset() = default;

    // Sets bit `index` (zero-based). Requires index < NUM_BITS(): a larger
    // index shifts past the width of the storage type, which is undefined
    // behavior.
    constexpr void set(size_t index) noexcept {
      bitset_ |= (static_cast<bitset_type>(1) << index);
    }

    // Clears bit `index` (zero-based). Requires index < NUM_BITS().
    constexpr void unset(size_t index) noexcept {
      bitset_ &= ~(static_cast<bitset_type>(1) << index);
    }

    // Returns whether bit `index` (zero-based) is set. Requires
    // index < NUM_BITS().
    constexpr bool get(size_t index) const noexcept {
      return bitset_ & (static_cast<bitset_type>(1) << index);
    }

    // Returns true iff no bit is set.
    constexpr bool is_entirely_unset() const noexcept {
      return 0 == bitset_;
    }

    // Calls func(i) once for each set bit i (zero-based), in ascending order
    // of i. Iteration runs over a local copy, so *this is never modified.
    template <class Func>
    // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
    void for_each_set_bit(Func&& func) const {
      bitset cur = *this;
      size_t index = cur.find_first_set();
      while (0 != index) {
        // find_first_set() is one-indexed; convert to a zero-based bit index.
        index -= 1;
        func(index);
        cur.unset(index);
        index = cur.find_first_set();
      }
    }

   private:
    // Returns the one-indexed position of the lowest set bit (i.e. 1 if the
    // very first bit is set), or 0 if no bit is set -- the ffsll() contract.
    size_t find_first_set() const noexcept {
  #if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
      unsigned long result;
      if (0 == _BitScanForward64(&result, bitset_)) {
        return 0;
      }
      return static_cast<size_t>(result) + 1;
  #elif defined(_MSC_VER) && defined(_M_IX86)
      // 32-bit x86 has no 64-bit scan intrinsic: scan the low word first,
      // then the high word.
      unsigned long result;
      const uint64_t bits = static_cast<uint64_t>(bitset_);
      const uint32_t low = static_cast<uint32_t>(bits);
      if (low != 0) {
        _BitScanForward(&result, low);
        return static_cast<size_t>(result) + 1;
      }
      const uint32_t high = static_cast<uint32_t>(bits >> 32);
      if (high != 0) {
        _BitScanForward(&result, high);
        return static_cast<size_t>(result) + 33;
      }
      // Both halves are zero. This must be 0: for_each_set_bit() treats any
      // nonzero value as a set bit and would otherwise never terminate.
      return 0;
  #else
      return static_cast<size_t>(__builtin_ffsll(bitset_));
  #endif
    }

    // Hidden friend: found only through argument-dependent lookup.
    friend bool operator==(bitset lhs, bitset rhs) noexcept {
      return lhs.bitset_ == rhs.bitset_;
    }

    bitset_type bitset_{0};
  };
  102. inline bool operator!=(bitset lhs, bitset rhs) noexcept {
  103. return !(lhs == rhs);
  104. }
  105. } // namespace c10::utils
  106. #else
  107. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  108. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)