// TensorUtils.h (6.1 KB)
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/DimVector.h>
  4. #include <ATen/EmptyTensor.h>
  5. #include <ATen/Tensor.h>
  6. #include <ATen/TensorGeometry.h>
  7. #include <ATen/Utils.h>
  8. #include <utility>
  // These functions are NOT in Utils.h, because this file depends on Tensor.h.
  // Checks a tensor-valued condition: evaluates `cond`, calls
  // `_is_all_true()` on it, reads the result back with `.item<bool>()`, and
  // passes that boolean plus the remaining arguments on to TORCH_CHECK.
  // NOTE(review): the expansion itself ends in ';', so the usual call form
  // `TORCH_CHECK_TENSOR_ALL(...);` produces an extra empty statement. The
  // semicolon is kept because call sites outside this file may depend on it.
  #define TORCH_CHECK_TENSOR_ALL(cond, ...) \
    TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__);
  12. namespace at {
  // The following are utility functions for checking that arguments
  // make sense. These are particularly useful for native functions,
  // which do NO argument checking by default.
  16. struct TORCH_API TensorArg {
  17. // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  18. const Tensor& tensor;
  19. const char* name;
  20. int pos; // 1-indexed
  21. TensorArg(const Tensor& tensor, const char* name, int pos)
  22. : tensor(tensor), name(name), pos(pos) {}
  23. // Try to mitigate any possibility of dangling reference to temporaries.
  24. // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
  25. TensorArg(Tensor&& tensor, const char* name, int pos) = delete;
  26. const Tensor* operator->() const {
  27. return &tensor;
  28. }
  29. const Tensor& operator*() const {
  30. return tensor;
  31. }
  32. };
  33. struct TORCH_API TensorGeometryArg {
  34. TensorGeometry tensor;
  35. const char* name;
  36. int pos; // 1-indexed
  37. /* implicit */ TensorGeometryArg(TensorArg arg)
  38. : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}
  39. TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
  40. : tensor(std::move(tensor)), name(name), pos(pos) {}
  41. const TensorGeometry* operator->() const {
  42. return &tensor;
  43. }
  44. const TensorGeometry& operator*() const {
  45. return tensor;
  46. }
  47. };
  // A string describing which function did checks on its input
  // arguments. It is a plain `const char*`, so no ownership is implied.
  // TODO: Consider generalizing this into a call stack.
  using CheckedFrom = const char*;
  // The undefined convention: singular operators assume their arguments
  // are defined, but functions which take multiple tensors will
  // implicitly filter out undefined tensors (to make it easier to perform
  // tests which should apply if the tensor is defined, and should not
  // otherwise).
  //
  // NB: This means that the n-ary operators take lists of TensorArg,
  // not TensorGeometryArg, because the Tensor to TensorGeometry
  // conversion will blow up if you have undefined tensors.
  61. TORCH_API std::ostream& operator<<(
  62. std::ostream& out,
  63. const TensorGeometryArg& t);
  64. TORCH_API void checkDim(
  65. CheckedFrom c,
  66. const Tensor& tensor,
  67. const char* name,
  68. int pos, // 1-indexed
  69. int64_t dim);
  70. TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim);
  71. // NB: this is an inclusive-exclusive range
  72. TORCH_API void checkDimRange(
  73. CheckedFrom c,
  74. const TensorGeometryArg& t,
  75. int64_t dim_start,
  76. int64_t dim_end);
  77. TORCH_API void checkSameDim(
  78. CheckedFrom c,
  79. const TensorGeometryArg& t1,
  80. const TensorGeometryArg& t2);
  81. TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
  82. TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
  83. TORCH_API void checkSize(
  84. CheckedFrom c,
  85. const TensorGeometryArg& t,
  86. IntArrayRef sizes);
  87. TORCH_API void checkSize_symint(
  88. CheckedFrom c,
  89. const TensorGeometryArg& t,
  90. c10::SymIntArrayRef sizes);
  91. TORCH_API void checkSize(
  92. CheckedFrom c,
  93. const TensorGeometryArg& t,
  94. int64_t dim,
  95. int64_t size);
  96. TORCH_API void checkSize_symint(
  97. CheckedFrom c,
  98. const TensorGeometryArg& t,
  99. int64_t dim,
  100. const c10::SymInt& size);
  101. TORCH_API void checkNumel(
  102. CheckedFrom c,
  103. const TensorGeometryArg& t,
  104. int64_t numel);
  105. TORCH_API void checkSameNumel(
  106. CheckedFrom c,
  107. const TensorArg& t1,
  108. const TensorArg& t2);
  109. TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
  110. TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
  111. TORCH_API void checkScalarTypes(
  112. CheckedFrom c,
  113. const TensorArg& t,
  114. at::ArrayRef<ScalarType> l);
  115. TORCH_API void checkSameGPU(
  116. CheckedFrom c,
  117. const TensorArg& t1,
  118. const TensorArg& t2);
  119. TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
  120. TORCH_API void checkSameType(
  121. CheckedFrom c,
  122. const TensorArg& t1,
  123. const TensorArg& t2);
  124. TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
  125. TORCH_API void checkSameSize(
  126. CheckedFrom c,
  127. const TensorArg& t1,
  128. const TensorArg& t2);
  129. TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors);
  130. TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t);
  131. TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
  132. // FixMe: does TensorArg slow things down?
  133. TORCH_API void checkBackend(
  134. CheckedFrom c,
  135. at::ArrayRef<Tensor> t,
  136. at::Backend backend);
  137. TORCH_API void checkDeviceType(
  138. CheckedFrom c,
  139. at::ArrayRef<Tensor> tensors,
  140. at::DeviceType device_type);
  141. TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
  142. TORCH_API void checkLayout(
  143. CheckedFrom c,
  144. at::ArrayRef<Tensor> tensors,
  145. at::Layout layout);
  146. // Methods for getting data_ptr if tensor is defined
  147. TORCH_API void* maybe_data_ptr(const Tensor& tensor);
  148. TORCH_API void* maybe_data_ptr(const TensorArg& tensor);
  149. TORCH_API void check_dim_size(
  150. const Tensor& tensor,
  151. int64_t dim,
  152. int64_t dim_size,
  153. int64_t size);
  154. namespace detail {
  155. TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
  156. TORCH_API std::optional<std::vector<int64_t>> computeStride(
  157. IntArrayRef oldshape,
  158. IntArrayRef oldstride,
  159. IntArrayRef newshape);
  160. TORCH_API std::optional<SymDimVector> computeStride(
  161. c10::SymIntArrayRef oldshape,
  162. c10::SymIntArrayRef oldstride,
  163. c10::SymIntArrayRef newshape);
  164. TORCH_API std::optional<DimVector> computeStride(
  165. IntArrayRef oldshape,
  166. IntArrayRef oldstride,
  167. const DimVector& newshape);
  168. } // namespace detail
  169. } // namespace at
  170. #else
  171. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  172. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)