ArrayRef.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. //===--- ArrayRef.h - Array Reference Wrapper -------------------*- C++ -*-===//
  3. //
  4. // The LLVM Compiler Infrastructure
  5. //
  6. // This file is distributed under the University of Illinois Open Source
  7. // License. See LICENSE.TXT for details.
  8. //
  9. //===----------------------------------------------------------------------===//
  10. // ATen: modified from llvm::ArrayRef.
  11. // removed llvm-specific functionality
  12. // removed some implicit const -> non-const conversions that rely on
  13. // complicated std::enable_if meta-programming
  14. // removed a bunch of slice variants for simplicity...
  15. #pragma once
  16. #include <c10/macros/Macros.h>
  17. #include <c10/util/Exception.h>
  18. #include <c10/util/SmallVector.h>
  19. #include <torch/headeronly/util/HeaderOnlyArrayRef.h>
  20. #include <array>
  21. #include <cstddef>
  22. #include <cstdint>
  23. #include <initializer_list>
  24. #include <iterator>
  25. #include <ostream>
  26. #include <type_traits>
  27. #include <vector>
  28. namespace c10 {
  29. /// ArrayRef - Represent a constant reference to an array (0 or more elements
  30. /// consecutively in memory), i.e. a start pointer and a length. It allows
  31. /// various APIs to take consecutive elements easily and conveniently.
  32. ///
  33. /// This class does not own the underlying data, it is expected to be used in
  34. /// situations where the data resides in some other buffer, whose lifetime
  35. /// extends past that of the ArrayRef. For this reason, it is not in general
  36. /// safe to store an ArrayRef.
  37. ///
  38. /// This is intended to be trivially copyable, so it should be passed by
  39. /// value.
  40. ///
  41. /// NOTE: We have refactored out the headeronly parts of the ArrayRef struct
  42. /// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of
  43. /// the underlying constexpr calls, we rely on apparent-type dispatch for
  44. /// inheritance. This should be fine because their memory format is the same,
  45. /// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods.
  46. /// However, you should prefer to use ArrayRef when possible, because its use
  47. /// of TORCH_CHECK will lead to better user-facing error messages.
  48. template <typename T>
  49. // ArrayRef cannot be derived from. Normally, we would use `final`
  50. // specifier to force this constraint at compile time. However, Intel
  51. // compiler does not recognize ArrayRef as a class template (which is
  52. // required in the definition of at::TensorAccessor, for instance)
  53. // when `final` specifier is used. So, we cannot define ArrayRef as
  54. // final because of the Intel compiler issue.
  55. class ArrayRef : public HeaderOnlyArrayRef<T> {
  56. public:
  57. /// @name Constructors, all inherited from HeaderOnlyArrayRef except for
  58. /// SmallVector. As inherited constructors won't work with class template
  59. /// argument deduction (CTAD) until C++23, we add deduction guides after
  60. /// the class definition to enable CTAD.
  61. /// @{
  62. using HeaderOnlyArrayRef<T>::HeaderOnlyArrayRef;
  63. /// Construct an ArrayRef from a SmallVector. This is templated in order to
  64. /// avoid instantiating SmallVectorTemplateCommon<T> whenever we
  65. /// copy-construct an ArrayRef.
  66. /// NOTE: this is the only constructor that is not inherited from
  67. /// HeaderOnlyArrayRef.
  68. template <typename U>
  69. /* implicit */ ArrayRef(const SmallVectorTemplateCommon<T, U>& Vec)
  70. : HeaderOnlyArrayRef<T>(Vec.data(), Vec.size()) {}
  71. /// @}
  72. /// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef
  73. /// @{
  74. /// front - Get the first element.
  75. /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  76. /// STD_TORCH_CHECK
  77. constexpr const T& front() const {
  78. TORCH_CHECK(
  79. !this->empty(), "ArrayRef: attempted to access front() of empty list");
  80. return this->Data[0];
  81. }
  82. /// back - Get the last element.
  83. /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  84. /// STD_TORCH_CHECK
  85. constexpr const T& back() const {
  86. TORCH_CHECK(
  87. !this->empty(), "ArrayRef: attempted to access back() of empty list");
  88. return this->Data[this->Length - 1];
  89. }
  90. /// slice(n, m) - Take M elements of the array starting at element N
  91. /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  92. /// STD_TORCH_CHECK
  93. constexpr ArrayRef<T> slice(size_t N, size_t M) const {
  94. TORCH_CHECK(
  95. N + M <= this->size(),
  96. "ArrayRef: invalid slice, N = ",
  97. N,
  98. "; M = ",
  99. M,
  100. "; size = ",
  101. this->size());
  102. return ArrayRef<T>(this->data() + N, M);
  103. }
  104. /// slice(n) - Chop off the first N elements of the array.
  105. /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  106. /// STD_TORCH_CHECK
  107. constexpr ArrayRef<T> slice(size_t N) const {
  108. TORCH_CHECK(
  109. N <= this->size(),
  110. "ArrayRef: invalid slice, N = ",
  111. N,
  112. "; size = ",
  113. this->size());
  114. return slice(N, this->size() - N); // should this slice be this->slice?
  115. }
  116. /// @}
  117. /// @name Operator Overloads
  118. /// @{
  119. /// Vector compatibility
  120. /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of
  121. /// STD_TORCH_CHECK
  122. constexpr const T& at(size_t Index) const {
  123. TORCH_CHECK(
  124. Index < this->Length,
  125. "ArrayRef: invalid index Index = ",
  126. Index,
  127. "; Length = ",
  128. this->Length);
  129. return this->Data[Index];
  130. }
  131. /// Disallow accidental assignment from a temporary.
  132. ///
  133. /// The declaration here is extra complicated so that "arrayRef = {}"
  134. /// continues to select the move assignment operator.
  135. template <typename U>
  136. std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
  137. // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  138. U&& Temporary) = delete;
  139. /// Disallow accidental assignment from a temporary.
  140. ///
  141. /// The declaration here is extra complicated so that "arrayRef = {}"
  142. /// continues to select the move assignment operator.
  143. template <typename U>
  144. std::enable_if_t<std::is_same_v<U, T>, ArrayRef<T>>& operator=(
  145. std::initializer_list<U>) = delete;
  146. /// @}
  147. };
  148. /// Deduction guides for ArrayRef to support CTAD with inherited constructors
  149. /// These mirror the constructors inherited from HeaderOnlyArrayRef
  150. /// @{
  151. // Single element constructor
  152. template <typename T>
  153. ArrayRef(const T&) -> ArrayRef<T>;
  154. // Pointer and length constructor
  155. template <typename T>
  156. ArrayRef(const T*, size_t) -> ArrayRef<T>;
  157. // Range constructor (begin, end)
  158. template <typename T>
  159. ArrayRef(const T*, const T*) -> ArrayRef<T>;
  160. // Generic container constructor (anything with .data() and .size())
  161. template <typename Container>
  162. ArrayRef(const Container&) -> ArrayRef<
  163. std::remove_pointer_t<decltype(std::declval<Container>().data())>>;
  164. // std::vector constructor
  165. template <typename T, typename A>
  166. ArrayRef(const std::vector<T, A>&) -> ArrayRef<T>;
  167. // std::array constructor
  168. template <typename T, size_t N>
  169. ArrayRef(const std::array<T, N>&) -> ArrayRef<T>;
  170. // C array constructor
  171. template <typename T, size_t N>
  172. ArrayRef(const T (&)[N]) -> ArrayRef<T>;
  173. // std::initializer_list constructor
  174. template <typename T>
  175. ArrayRef(const std::initializer_list<T>&) -> ArrayRef<T>;
  176. /// @}
  177. template <typename T>
  178. std::ostream& operator<<(std::ostream& out, ArrayRef<T> list) {
  179. int i = 0;
  180. out << '[';
  181. for (const auto& e : list) {
  182. if (i++ > 0)
  183. out << ", ";
  184. out << e;
  185. }
  186. out << ']';
  187. return out;
  188. }
  189. /// @name ArrayRef Convenience constructors
  190. /// @{
  191. /// Construct an ArrayRef from a single element.
  192. template <typename T>
  193. ArrayRef<T> makeArrayRef(const T& OneElt) {
  194. return OneElt;
  195. }
  196. /// Construct an ArrayRef from a pointer and length.
  197. template <typename T>
  198. ArrayRef<T> makeArrayRef(const T* data, size_t length) {
  199. return ArrayRef<T>(data, length);
  200. }
  201. /// Construct an ArrayRef from a range.
  202. template <typename T>
  203. ArrayRef<T> makeArrayRef(const T* begin, const T* end) {
  204. return ArrayRef<T>(begin, end);
  205. }
  206. /// Construct an ArrayRef from a SmallVector.
  207. template <typename T>
  208. ArrayRef<T> makeArrayRef(const SmallVectorImpl<T>& Vec) {
  209. return Vec;
  210. }
  211. /// Construct an ArrayRef from a SmallVector.
  212. template <typename T, unsigned N>
  213. ArrayRef<T> makeArrayRef(const SmallVector<T, N>& Vec) {
  214. return Vec;
  215. }
  216. /// Construct an ArrayRef from a std::vector.
  217. template <typename T>
  218. ArrayRef<T> makeArrayRef(const std::vector<T>& Vec) {
  219. return Vec;
  220. }
  221. /// Construct an ArrayRef from a std::array.
  222. template <typename T, std::size_t N>
  223. ArrayRef<T> makeArrayRef(const std::array<T, N>& Arr) {
  224. return Arr;
  225. }
  226. /// Construct an ArrayRef from an ArrayRef (no-op) (const)
  227. template <typename T>
  228. ArrayRef<T> makeArrayRef(const ArrayRef<T>& Vec) {
  229. return Vec;
  230. }
  231. /// Construct an ArrayRef from an ArrayRef (no-op)
  232. template <typename T>
  233. ArrayRef<T>& makeArrayRef(ArrayRef<T>& Vec) {
  234. return Vec;
  235. }
  236. /// Construct an ArrayRef from a C array.
  237. template <typename T, size_t N>
  238. // NOLINTNEXTLINE(*c-arrays*)
  239. ArrayRef<T> makeArrayRef(const T (&Arr)[N]) {
  240. return ArrayRef<T>(Arr);
  241. }
  242. // WARNING: Template instantiation will NOT be willing to do an implicit
  243. // conversions to get you to an c10::ArrayRef, which is why we need so
  244. // many overloads.
  245. template <typename T>
  246. bool operator==(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
  247. return a1.equals(a2);
  248. }
  249. template <typename T>
  250. bool operator!=(c10::ArrayRef<T> a1, c10::ArrayRef<T> a2) {
  251. return !a1.equals(a2);
  252. }
  253. template <typename T>
  254. bool operator==(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
  255. return c10::ArrayRef<T>(a1).equals(a2);
  256. }
  257. template <typename T>
  258. bool operator!=(const std::vector<T>& a1, c10::ArrayRef<T> a2) {
  259. return !c10::ArrayRef<T>(a1).equals(a2);
  260. }
  261. template <typename T>
  262. bool operator==(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
  263. return a1.equals(c10::ArrayRef<T>(a2));
  264. }
  265. template <typename T>
  266. bool operator!=(c10::ArrayRef<T> a1, const std::vector<T>& a2) {
  267. return !a1.equals(c10::ArrayRef<T>(a2));
  268. }
  269. using IntArrayRef = ArrayRef<int64_t>;
  270. using IntList [[deprecated(
  271. "This alias is deprecated because it doesn't make ownership semantics obvious. Use IntArrayRef instead!")]] =
  272. ArrayRef<int64_t>;
  273. } // namespace c10
  274. #else
  275. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  276. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)