EmptyTensor.h 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/TensorBase.h>
  4. namespace at::detail {
  5. inline void check_size_nonnegative(ArrayRef<int64_t> size) {
  6. for (const auto& x : size) {
  7. TORCH_CHECK(
  8. x >= 0,
  9. "Trying to create tensor with negative dimension ",
  10. x,
  11. ": ",
  12. size);
  13. }
  14. }
  15. inline void check_size_nonnegative(ArrayRef<c10::SymInt> size) {
  16. for (const auto& x : size) {
  17. TORCH_SYM_CHECK(
  18. x.sym_ge(0),
  19. "Trying to create tensor with negative dimension ",
  20. x,
  21. ": ",
  22. size);
  23. }
  24. }
  25. TORCH_API size_t computeStorageNbytesContiguous(
  26. IntArrayRef sizes,
  27. size_t itemsize,
  28. size_t storage_offset = 0);
  29. TORCH_API SymInt computeStorageNbytesContiguous(
  30. SymIntArrayRef sizes,
  31. const SymInt& itemsize,
  32. const SymInt& storage_offset = 0);
  33. TORCH_API size_t computeStorageNbytes(
  34. IntArrayRef sizes,
  35. IntArrayRef strides,
  36. size_t itemsize,
  37. size_t storage_offset = 0);
  38. TORCH_API SymInt computeStorageNbytes(
  39. SymIntArrayRef sizes,
  40. SymIntArrayRef strides,
  41. const SymInt& itemsize,
  42. const SymInt& storage_offset = 0);
  43. TORCH_API TensorBase empty_generic(
  44. IntArrayRef size,
  45. c10::Allocator* allocator,
  46. c10::DispatchKeySet ks,
  47. ScalarType scalar_type,
  48. std::optional<c10::MemoryFormat> memory_format_opt);
  49. TORCH_API TensorBase empty_generic_symint(
  50. SymIntArrayRef size,
  51. c10::Allocator* allocator,
  52. c10::DispatchKeySet ks,
  53. ScalarType scalar_type,
  54. std::optional<c10::MemoryFormat> memory_format_opt);
  55. TORCH_API TensorBase empty_strided_generic(
  56. IntArrayRef size,
  57. IntArrayRef stride,
  58. c10::Allocator* allocator,
  59. c10::DispatchKeySet ks,
  60. ScalarType scalar_type);
  61. TORCH_API TensorBase empty_strided_symint_generic(
  62. SymIntArrayRef size,
  63. SymIntArrayRef stride,
  64. c10::Allocator* allocator,
  65. c10::DispatchKeySet ks,
  66. ScalarType scalar_type);
  67. TORCH_API TensorBase empty_cpu(
  68. IntArrayRef size,
  69. ScalarType dtype,
  70. bool pin_memory = false,
  71. std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
  72. TORCH_API TensorBase empty_cpu(
  73. IntArrayRef size,
  74. std::optional<ScalarType> dtype_opt,
  75. std::optional<Layout> layout_opt,
  76. std::optional<Device> device_opt,
  77. std::optional<bool> pin_memory_opt,
  78. std::optional<c10::MemoryFormat> memory_format_opt);
  79. TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options);
  80. TORCH_API TensorBase empty_strided_cpu(
  81. IntArrayRef size,
  82. IntArrayRef stride,
  83. ScalarType dtype,
  84. bool pin_memory = false);
  85. TORCH_API TensorBase empty_strided_cpu(
  86. IntArrayRef size,
  87. IntArrayRef stride,
  88. std::optional<ScalarType> dtype_opt,
  89. std::optional<Layout> layout_opt,
  90. std::optional<Device> device_opt,
  91. std::optional<bool> pin_memory_opt);
  92. TORCH_API TensorBase empty_strided_cpu(
  93. IntArrayRef size,
  94. IntArrayRef stride,
  95. const TensorOptions& options);
  96. TORCH_API TensorBase empty_meta(
  97. IntArrayRef size,
  98. ScalarType dtype,
  99. std::optional<c10::MemoryFormat> memory_format_opt = std::nullopt);
  100. TORCH_API TensorBase empty_meta(
  101. IntArrayRef size,
  102. std::optional<ScalarType> dtype_opt,
  103. std::optional<Layout> layout_opt,
  104. std::optional<Device> device_opt,
  105. std::optional<bool> pin_memory_opt,
  106. std::optional<c10::MemoryFormat> memory_format_opt);
  107. TORCH_API TensorBase empty_symint_meta(
  108. SymIntArrayRef size,
  109. std::optional<ScalarType> dtype_opt,
  110. std::optional<Layout> layout_opt,
  111. std::optional<Device> device_opt,
  112. std::optional<bool> pin_memory_opt,
  113. std::optional<c10::MemoryFormat> memory_format_opt);
  114. TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options);
  115. TORCH_API TensorBase
  116. empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype);
  117. TORCH_API TensorBase empty_strided_meta(
  118. IntArrayRef size,
  119. IntArrayRef stride,
  120. std::optional<ScalarType> dtype_opt,
  121. std::optional<Layout> layout_opt,
  122. std::optional<Device> device_opt,
  123. std::optional<bool> pin_memory_opt);
  124. TORCH_API TensorBase empty_strided_meta(
  125. IntArrayRef size,
  126. IntArrayRef stride,
  127. const TensorOptions& options);
  128. TORCH_API TensorBase empty_strided_symint_meta(
  129. SymIntArrayRef size,
  130. SymIntArrayRef stride,
  131. ScalarType dtype);
  132. TORCH_API TensorBase empty_strided_symint_meta(
  133. SymIntArrayRef size,
  134. SymIntArrayRef stride,
  135. std::optional<ScalarType> dtype_opt,
  136. std::optional<Layout> layout_opt,
  137. std::optional<Device> device_opt);
  138. TORCH_API TensorBase empty_strided_symint_meta(
  139. SymIntArrayRef size,
  140. SymIntArrayRef stride,
  141. const TensorOptions& options);
  142. } // namespace at::detail
  143. #else
  144. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  145. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)