Utils.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/EmptyTensor.h>
  4. #include <ATen/Formatting.h>
  5. #include <ATen/core/ATenGeneral.h>
  6. #include <ATen/core/Generator.h>
  7. #include <c10/core/ScalarType.h>
  8. #include <c10/core/StorageImpl.h>
  9. #include <c10/core/UndefinedTensorImpl.h>
  10. #include <c10/util/ArrayRef.h>
  11. #include <c10/util/Exception.h>
  12. #include <c10/util/accumulate.h>
  13. #include <c10/util/irange.h>
  14. #include <algorithm>
  15. #define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  16. TypeName(const TypeName&) = delete; \
  17. void operator=(const TypeName&) = delete
  18. namespace at {
// Exported helper; only the declaration is visible here, the definition lives
// in a separate translation unit. NOTE(review): judging by the name it
// presumably triggers a deliberate memory error so tests can confirm that
// AddressSanitizer instrumentation is active — confirm against the definition.
// The meaning of the int argument and of the returned int is not visible from
// this header.
TORCH_API int _crash_if_asan(int /*arg*/);
  20. // Converts a TensorList (i.e. ArrayRef<Tensor> to vector of TensorImpl*)
  21. // NB: This is ONLY used by legacy TH bindings, and ONLY used by cat.
  22. // Once cat is ported entirely to ATen this can be deleted!
  23. inline std::vector<TensorImpl*> checked_dense_tensor_list_unwrap(
  24. ArrayRef<Tensor> tensors,
  25. const char* name,
  26. int pos,
  27. c10::DeviceType device_type,
  28. ScalarType scalar_type) {
  29. std::vector<TensorImpl*> unwrapped;
  30. unwrapped.reserve(tensors.size());
  31. for (const auto i : c10::irange(tensors.size())) {
  32. const auto& expr = tensors[i];
  33. if (expr.layout() != Layout::Strided) {
  34. TORCH_CHECK(
  35. false,
  36. "Expected dense tensor but got ",
  37. expr.layout(),
  38. " for sequence element ",
  39. i,
  40. " in sequence argument at position #",
  41. pos,
  42. " '",
  43. name,
  44. "'");
  45. }
  46. if (expr.device().type() != device_type) {
  47. TORCH_CHECK(
  48. false,
  49. "Expected object of device type ",
  50. device_type,
  51. " but got device type ",
  52. expr.device().type(),
  53. " for sequence element ",
  54. i,
  55. " in sequence argument at position #",
  56. pos,
  57. " '",
  58. name,
  59. "'");
  60. }
  61. if (expr.scalar_type() != scalar_type) {
  62. TORCH_CHECK(
  63. false,
  64. "Expected object of scalar type ",
  65. scalar_type,
  66. " but got scalar type ",
  67. expr.scalar_type(),
  68. " for sequence element ",
  69. i,
  70. " in sequence argument at position #",
  71. pos,
  72. " '",
  73. name,
  74. "'");
  75. }
  76. unwrapped.emplace_back(expr.unsafeGetTensorImpl());
  77. }
  78. return unwrapped;
  79. }
  80. template <size_t N>
  81. std::array<int64_t, N> check_intlist(
  82. ArrayRef<int64_t> list,
  83. const char* name,
  84. int pos) {
  85. if (list.empty()) {
  86. // TODO: is this necessary? We used to treat nullptr-vs-not in IntList
  87. // differently with strides as a way of faking optional.
  88. list = {};
  89. }
  90. auto res = std::array<int64_t, N>();
  91. if (list.size() == 1 && N > 1) {
  92. res.fill(list[0]);
  93. return res;
  94. }
  95. if (list.size() != N) {
  96. TORCH_CHECK(
  97. false,
  98. "Expected a list of ",
  99. N,
  100. " ints but got ",
  101. list.size(),
  102. " for argument #",
  103. pos,
  104. " '",
  105. name,
  106. "'");
  107. }
  108. std::copy_n(list.begin(), N, res.begin());
  109. return res;
  110. }
  111. using at::detail::check_size_nonnegative;
  112. namespace detail {
  113. template <typename T>
  114. TORCH_API Tensor tensor_cpu(ArrayRef<T> values, const TensorOptions& options);
  115. template <typename T>
  116. TORCH_API Tensor
  117. tensor_backend(ArrayRef<T> values, const TensorOptions& options);
  118. template <typename T>
  119. TORCH_API Tensor
  120. tensor_complex_cpu(ArrayRef<T> values, const TensorOptions& options);
  121. template <typename T>
  122. TORCH_API Tensor
  123. tensor_complex_backend(ArrayRef<T> values, const TensorOptions& options);
  124. } // namespace detail
  125. } // namespace at
  126. #else
  127. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  128. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)