TensorGeometry.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/TensorBase.h>
  4. #include <c10/core/WrapDimMinimal.h>
  5. namespace at {
  6. // Returns whether the tensor geometry represented by `sizes` and `strides`
  7. // is contiguous. Although is_contiguous is now cached in the tensor, this is
  8. // still useful because it allows checking whether a particular geometry is
  9. // contiguous without explicitly constructing a tensor, e.g., when you want to
  10. // choose a kernel strategy based on whether a subgeometry is contiguous.
  11. TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
  12. struct TORCH_API TensorGeometry {
  13. TensorGeometry() = default;
  14. explicit TensorGeometry(c10::SymIntArrayRef sizes)
  15. : sizes_(sizes.vec()),
  16. strides_(sizes.size()),
  17. has_symbolic_sizes_strides_(
  18. !c10::asIntArrayRefSlowOpt(sizes).has_value()) {
  19. int64_t dim = static_cast<int64_t>(sizes.size());
  20. c10::SymInt expected_stride = 1;
  21. for (int64_t i = dim - 1; i >= 0; i--) {
  22. strides_[i] = expected_stride;
  23. expected_stride *= sizes_[i];
  24. }
  25. numel_ = expected_stride;
  26. }
  27. explicit TensorGeometry(const TensorBase& t)
  28. : sizes_(t.sym_sizes().vec()),
  29. strides_(t.sym_strides().vec()),
  30. storage_offset_(t.sym_storage_offset()),
  31. numel_(t.sym_numel()),
  32. has_symbolic_sizes_strides_(
  33. t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}
  34. explicit TensorGeometry(
  35. std::vector<at::SymInt> sizes,
  36. std::vector<at::SymInt> strides,
  37. at::SymInt storage_offset)
  38. : sizes_(std::move(sizes)),
  39. strides_(std::move(strides)),
  40. storage_offset_(std::move(storage_offset)) {
  41. recompute();
  42. }
  43. // true if the tensor is contiguous
  44. bool is_contiguous() const;
  45. int64_t dim() const {
  46. return static_cast<int64_t>(sizes_.size());
  47. }
  48. int64_t size(int64_t dim) const {
  49. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  50. dim = c10::maybe_wrap_dim(dim, this->dim());
  51. return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked();
  52. }
  53. c10::IntArrayRef sizes() const {
  54. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  55. return c10::asIntArrayRefUnchecked(sizes_);
  56. }
  57. int64_t stride(int64_t dim) const {
  58. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  59. dim = c10::maybe_wrap_dim(dim, this->dim());
  60. return strides_.at(static_cast<size_t>(dim)).as_int_unchecked();
  61. }
  62. c10::IntArrayRef strides() const {
  63. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  64. return c10::asIntArrayRefUnchecked(strides_);
  65. }
  66. int64_t storage_offset() const {
  67. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  68. return storage_offset_.as_int_unchecked();
  69. }
  70. int64_t numel() const {
  71. TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
  72. return numel_.as_int_unchecked();
  73. }
  74. c10::SymInt sym_size(int64_t dim) const {
  75. dim = c10::maybe_wrap_dim(dim, this->dim());
  76. return sizes_.at(static_cast<size_t>(dim));
  77. }
  78. c10::SymIntArrayRef sym_sizes() const {
  79. return sizes_;
  80. }
  81. c10::SymInt sym_stride(int64_t dim) const {
  82. dim = c10::maybe_wrap_dim(dim, this->dim());
  83. return strides_.at(static_cast<size_t>(dim));
  84. }
  85. c10::SymIntArrayRef sym_strides() const {
  86. return strides_;
  87. }
  88. c10::SymInt sym_storage_offset() const {
  89. return storage_offset_;
  90. }
  91. c10::SymInt sym_numel() const {
  92. return numel_;
  93. }
  94. TensorGeometry transpose(int64_t dim0, int64_t dim1) {
  95. TensorGeometry r = *this; // copy
  96. TORCH_CHECK(
  97. dim0 < dim(),
  98. "transpose: dim0=",
  99. dim0,
  100. " out of range (dim=",
  101. dim(),
  102. ")")
  103. TORCH_CHECK(
  104. dim1 < dim(),
  105. "transpose: dim1=",
  106. dim1,
  107. " out of range (dim=",
  108. dim(),
  109. ")")
  110. std::swap(r.sizes_[dim0], r.sizes_[dim1]);
  111. std::swap(r.strides_[dim0], r.strides_[dim1]);
  112. return r;
  113. }
  114. std::vector<c10::SymInt>& mutable_sizes() {
  115. return sizes_;
  116. }
  117. std::vector<c10::SymInt>& mutable_strides() {
  118. return strides_;
  119. }
  120. c10::SymInt& mutable_storage_offset() {
  121. return storage_offset_;
  122. }
  123. void recompute() {
  124. // recalculate numel after a change
  125. c10::SymInt numel = 1;
  126. for (const auto& i : sizes_) {
  127. numel = numel * i;
  128. }
  129. numel_ = std::move(numel);
  130. has_symbolic_sizes_strides_ =
  131. !c10::asIntArrayRefSlowOpt(sizes_).has_value();
  132. }
  133. private:
  134. std::vector<c10::SymInt> sizes_;
  135. std::vector<c10::SymInt> strides_;
  136. c10::SymInt storage_offset_;
  137. c10::SymInt numel_;
  138. bool has_symbolic_sizes_strides_{false};
  139. };
  140. } // namespace at
  141. #else
  142. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  143. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)