OptionalArrayRef.h 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
// This file defines OptionalArrayRef<T>, a class with almost exactly the same
// functionality as std::optional<ArrayRef<T>>, except that its converting
// constructor fixes a dangling pointer issue.
  5. //
// The implicit converting constructor of std::optional<ArrayRef<T>> can cause
// the underlying ArrayRef<T> to store a dangling pointer. OptionalArrayRef<T>
// prevents this by wrapping a std::optional<ArrayRef<T>> and fixing the
// constructor implementation.
  10. //
  11. // See https://github.com/pytorch/pytorch/issues/63645 for more on this.
  12. #pragma once
  13. #include <c10/util/ArrayRef.h>
  14. #include <cstdint>
  15. #include <initializer_list>
  16. #include <optional>
  17. #include <type_traits>
  18. #include <utility>
  19. namespace c10 {
  20. template <typename T>
  21. class OptionalArrayRef final {
  22. public:
  23. // Constructors
  24. constexpr OptionalArrayRef() noexcept = default;
  25. constexpr OptionalArrayRef(std::nullopt_t /*unused*/) noexcept {}
  26. OptionalArrayRef(const OptionalArrayRef& other) = default;
  27. OptionalArrayRef(OptionalArrayRef&& other) noexcept = default;
  28. constexpr OptionalArrayRef(const std::optional<ArrayRef<T>>& other) noexcept
  29. : wrapped_opt_array_ref(other) {}
  30. constexpr OptionalArrayRef(std::optional<ArrayRef<T>>&& other) noexcept
  31. : wrapped_opt_array_ref(std::move(other)) {}
  32. constexpr OptionalArrayRef(const T& value) noexcept
  33. : wrapped_opt_array_ref(value) {}
  34. template <
  35. typename U = ArrayRef<T>,
  36. std::enable_if_t<
  37. !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
  38. !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
  39. std::is_constructible_v<ArrayRef<T>, U&&> &&
  40. std::is_convertible_v<U&&, ArrayRef<T>> &&
  41. !std::is_convertible_v<U&&, T>,
  42. bool> = false>
  43. constexpr OptionalArrayRef(U&& value) noexcept(
  44. std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
  45. : wrapped_opt_array_ref(std::forward<U>(value)) {}
  46. template <
  47. typename U = ArrayRef<T>,
  48. std::enable_if_t<
  49. !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
  50. !std::is_same_v<std::decay_t<U>, std::in_place_t> &&
  51. std::is_constructible_v<ArrayRef<T>, U&&> &&
  52. !std::is_convertible_v<U&&, ArrayRef<T>>,
  53. bool> = false>
  54. constexpr explicit OptionalArrayRef(U&& value) noexcept(
  55. std::is_nothrow_constructible_v<ArrayRef<T>, U&&>)
  56. : wrapped_opt_array_ref(std::forward<U>(value)) {}
  57. template <typename... Args>
  58. constexpr explicit OptionalArrayRef(
  59. std::in_place_t ip,
  60. Args&&... args) noexcept
  61. : wrapped_opt_array_ref(ip, std::forward<Args>(args)...) {}
  62. template <typename U, typename... Args>
  63. constexpr explicit OptionalArrayRef(
  64. std::in_place_t ip,
  65. std::initializer_list<U> il,
  66. Args&&... args)
  67. : wrapped_opt_array_ref(ip, il, std::forward<Args>(args)...) {}
  68. constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
  69. : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
  70. // Destructor
  71. ~OptionalArrayRef() = default;
  72. // Assignment
  73. constexpr OptionalArrayRef& operator=(std::nullopt_t /*unused*/) noexcept {
  74. wrapped_opt_array_ref = std::nullopt;
  75. return *this;
  76. }
  77. OptionalArrayRef& operator=(const OptionalArrayRef& other) = default;
  78. OptionalArrayRef& operator=(OptionalArrayRef&& other) noexcept = default;
  79. constexpr OptionalArrayRef& operator=(
  80. const std::optional<ArrayRef<T>>& other) noexcept {
  81. wrapped_opt_array_ref = other;
  82. return *this;
  83. }
  84. constexpr OptionalArrayRef& operator=(
  85. std::optional<ArrayRef<T>>&& other) noexcept {
  86. wrapped_opt_array_ref = std::move(other);
  87. return *this;
  88. }
  89. template <
  90. typename U = ArrayRef<T>,
  91. typename = std::enable_if_t<
  92. !std::is_same_v<std::decay_t<U>, OptionalArrayRef> &&
  93. std::is_constructible_v<ArrayRef<T>, U&&> &&
  94. std::is_assignable_v<ArrayRef<T>&, U&&>>>
  95. constexpr OptionalArrayRef& operator=(U&& value) noexcept(
  96. std::is_nothrow_constructible_v<ArrayRef<T>, U&&> &&
  97. std::is_nothrow_assignable_v<ArrayRef<T>&, U&&>) {
  98. wrapped_opt_array_ref = std::forward<U>(value);
  99. return *this;
  100. }
  101. // Observers
  102. constexpr ArrayRef<T>* operator->() noexcept {
  103. return &wrapped_opt_array_ref.value();
  104. }
  105. constexpr const ArrayRef<T>* operator->() const noexcept {
  106. return &wrapped_opt_array_ref.value();
  107. }
  108. constexpr ArrayRef<T>& operator*() & noexcept {
  109. return wrapped_opt_array_ref.value();
  110. }
  111. constexpr const ArrayRef<T>& operator*() const& noexcept {
  112. return wrapped_opt_array_ref.value();
  113. }
  114. constexpr ArrayRef<T>&& operator*() && noexcept {
  115. return std::move(wrapped_opt_array_ref.value());
  116. }
  117. constexpr const ArrayRef<T>&& operator*() const&& noexcept {
  118. return std::move(wrapped_opt_array_ref.value());
  119. }
  120. constexpr explicit operator bool() const noexcept {
  121. return wrapped_opt_array_ref.has_value();
  122. }
  123. constexpr bool has_value() const noexcept {
  124. return wrapped_opt_array_ref.has_value();
  125. }
  126. constexpr ArrayRef<T>& value() & {
  127. return wrapped_opt_array_ref.value();
  128. }
  129. constexpr const ArrayRef<T>& value() const& {
  130. // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  131. return wrapped_opt_array_ref.value();
  132. }
  133. constexpr ArrayRef<T>&& value() && {
  134. return std::move(wrapped_opt_array_ref.value());
  135. }
  136. constexpr const ArrayRef<T>&& value() const&& {
  137. return std::move(wrapped_opt_array_ref.value());
  138. }
  139. template <typename U>
  140. constexpr std::
  141. enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
  142. value_or(U&& default_value) const& {
  143. return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
  144. }
  145. template <typename U>
  146. constexpr std::
  147. enable_if_t<std::is_convertible_v<U&&, ArrayRef<T>>, ArrayRef<T>>
  148. value_or(U&& default_value) && {
  149. return wrapped_opt_array_ref.value_or(std::forward<U>(default_value));
  150. }
  151. // Modifiers
  152. constexpr void swap(OptionalArrayRef& other) noexcept {
  153. std::swap(wrapped_opt_array_ref, other.wrapped_opt_array_ref);
  154. }
  155. constexpr void reset() noexcept {
  156. wrapped_opt_array_ref.reset();
  157. }
  158. template <typename... Args>
  159. constexpr std::
  160. enable_if_t<std::is_constructible_v<ArrayRef<T>, Args&&...>, ArrayRef<T>&>
  161. emplace(Args&&... args) noexcept(
  162. std::is_nothrow_constructible_v<ArrayRef<T>, Args&&...>) {
  163. return wrapped_opt_array_ref.emplace(std::forward<Args>(args)...);
  164. }
  165. template <typename U, typename... Args>
  166. constexpr ArrayRef<T>& emplace(
  167. std::initializer_list<U> il,
  168. Args&&... args) noexcept {
  169. return wrapped_opt_array_ref.emplace(il, std::forward<Args>(args)...);
  170. }
  171. private:
  172. std::optional<ArrayRef<T>> wrapped_opt_array_ref;
  173. };
  174. using OptionalIntArrayRef = OptionalArrayRef<int64_t>;
  175. inline bool operator==(
  176. const OptionalIntArrayRef& a1,
  177. const IntArrayRef& other) {
  178. if (!a1.has_value()) {
  179. return false;
  180. }
  181. return a1.value() == other;
  182. }
  183. inline bool operator==(
  184. const c10::IntArrayRef& a1,
  185. const c10::OptionalIntArrayRef& a2) {
  186. return a2 == a1;
  187. }
  188. } // namespace c10
  189. #else
  190. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  191. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)