WrapDimUtils.h 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/IListRef.h>
  4. #include <ATen/core/Tensor.h>
  5. #include <c10/core/TensorImpl.h>
  6. #include <c10/core/WrapDimMinimal.h>
  7. #include <c10/util/irange.h>
  8. namespace at {
// If dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
// range [-1, 0]. This is a special case for scalar tensors and manifests in,
// e.g., torch.sum(scalar_tensor, 0). Otherwise, dim should be in the range
// [-dim_post_expr, dim_post_expr-1].
  13. using c10::maybe_wrap_dim;
  14. inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
  15. return maybe_wrap_dim(dim, tensor->dim());
  16. }
  17. inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
  18. if (tensors.empty()) {
  19. // can't wrap empty TensorList; rely on underlying implementation to throw
  20. // error if necessary.
  21. return dim;
  22. }
  23. return maybe_wrap_dim(dim, tensors[0].dim());
  24. }
  25. inline int64_t maybe_wrap_dim(
  26. int64_t dim,
  27. const std::vector<std::vector<int64_t>>& tensor_sizes) {
  28. if (tensor_sizes.empty()) {
  29. // can't wrap empty list; rely on underlying implementation to throw error
  30. // if necessary
  31. return dim;
  32. }
  33. return maybe_wrap_dim(dim, static_cast<int64_t>(tensor_sizes[0].size()));
  34. }
  35. // Given an array of dimensions `dims` of length `ndims`, this function "Wraps"
  36. // each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
  37. // specified using negative indices.
  38. //
  39. // Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will
  40. // allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
  41. // dimensions not in the range [-dim_post_expr, dim_post_expr).
  42. inline void maybe_wrap_dims_n(
  43. int64_t* dims,
  44. int64_t ndims,
  45. int64_t dim_post_expr,
  46. bool wrap_scalars = true) {
  47. if (dim_post_expr <= 0) {
  48. if (wrap_scalars) {
  49. dim_post_expr = 1; // this will make range [-1, 0]
  50. } else {
  51. TORCH_CHECK_INDEX(
  52. ndims == 0,
  53. "Dimension specified as ",
  54. dims[0],
  55. " but tensor has no dimensions");
  56. return;
  57. }
  58. }
  59. int64_t min = -dim_post_expr;
  60. int64_t max = dim_post_expr - 1;
  61. for (const auto i : c10::irange(ndims)) {
  62. auto& dim = dims[i];
  63. if (dim < min || dim > max) {
  64. TORCH_CHECK_INDEX(
  65. false,
  66. "Dimension out of range (expected to be in range of [",
  67. min,
  68. ", ",
  69. max,
  70. "], but got ",
  71. dim,
  72. ")");
  73. }
  74. if (dim < 0)
  75. dim += dim_post_expr;
  76. }
  77. }
  78. // Given a contiguous container of dimensions `dims`, this function "Wraps"
  79. // each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
  80. // specified using negative indices.
  81. //
  82. // Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will
  83. // allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
  84. // dimensions not in the range [-dim_post_expr, dim_post_expr).
  85. template <typename Container>
  86. inline void maybe_wrap_dims(
  87. Container& dims,
  88. int64_t dim_post_expr,
  89. bool wrap_scalars = true) {
  90. return maybe_wrap_dims_n(
  91. dims.data(), dims.size(), dim_post_expr, wrap_scalars);
  92. }
  93. // previously, size [0] tensors were the only possible empty tensors; thus, it
  94. // wasn't possible to cat empty tensors unless all the other tensors were
  95. // 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
  96. // dimension behavior and dimension size checking). We maintain this behavior
  97. // for backwards compatibility, but only for this specific size (i.e. other
  98. // empty sizes are not skipped).
  99. inline int64_t legacy_cat_wrap_dim(
  100. int64_t dim,
  101. const std::vector<std::vector<int64_t>>& tensor_sizes) {
  102. for (auto& sizes : tensor_sizes) {
  103. if (sizes.size() == 1 && sizes[0] == 0) {
  104. continue;
  105. }
  106. return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
  107. }
  108. return dim;
  109. }
  110. inline int64_t legacy_cat_wrap_dim_symint(
  111. int64_t dim,
  112. const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
  113. for (auto& sizes : tensor_sizes) {
  114. if (sizes.size() == 1) {
  115. if (TORCH_GUARD_OR_FALSE(sizes[0].sym_eq(0))) {
  116. continue;
  117. }
  118. }
  119. return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
  120. }
  121. return dim;
  122. }
  123. inline int64_t legacy_cat_wrap_dim(
  124. int64_t dim,
  125. const MaterializedITensorListRef& tensors) {
  126. for (const Tensor& tensor : tensors) {
  127. if (tensor.dim() == 1) {
  128. if (TORCH_GUARD_OR_FALSE(tensor.sym_sizes()[0].sym_eq(0))) {
  129. continue;
  130. }
  131. }
  132. return maybe_wrap_dim(dim, tensor.dim());
  133. }
  134. return dim;
  135. }
  136. // wrap negative dims in a vector
  137. inline void wrap_all_dims(
  138. std::vector<int64_t>& dims_to_wrap,
  139. int64_t tensor_total_dims) {
  140. for (const auto i : c10::irange(dims_to_wrap.size())) {
  141. dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
  142. }
  143. }
  144. } // namespace at
  145. #else
  146. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  147. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)