ScatterGatherChecks.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <vector>
  4. #include <ATen/core/Tensor.h>
  5. #include <ATen/native/ReduceOpsUtils.h>
  6. #include <c10/util/irange.h>
  7. namespace at::native {
  8. namespace {
  9. // checks whether index.dtype == int64
  10. // and self.dtype == src.dtype if src is a Tensor
  11. inline void scatter_gather_dtype_check(
  12. const std::string& method_name,
  13. const Tensor& self,
  14. const Tensor& index,
  15. const std::optional<Tensor>& src_opt = std::nullopt
  16. ) {
  17. if (index.numel() != 0) {
  18. TORCH_CHECK(
  19. index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int,
  20. method_name, "(): Expected dtype int32/int64 for index"
  21. );
  22. }
  23. if (src_opt.has_value()) {
  24. const auto& src = src_opt.value();
  25. TORCH_CHECK(
  26. self.scalar_type() == src.scalar_type(),
  27. method_name, "(): Expected self.dtype to be equal to src.dtype"
  28. );
  29. }
  30. }
  31. // Used for `gather`-like methods
  32. // Note: self means the input tensor here
  33. // Test:
  34. // 1. index.size(d) <= self.size(d) for all d != dim
  35. // 2. index.dim() == self.dim()
  36. inline void gather_shape_check(const Tensor& self, int64_t dim,
  37. const Tensor& index
  38. ) {
  39. auto self_dims = ensure_nonempty_dim(self.dim());
  40. TORCH_CHECK(self_dims == ensure_nonempty_dim(index.dim()),
  41. "Index tensor must have the same number of dimensions as input tensor"
  42. );
  43. for (const auto i : c10::irange(self_dims)) {
  44. if (i != dim) {
  45. TORCH_CHECK(
  46. ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
  47. "Size does not match at dimension ", i,
  48. " expected index ", index.sizes(),
  49. " to be no larger than self ", self.sizes(),
  50. " apart from dimension ", dim
  51. );
  52. }
  53. }
  54. }
  55. // Used for `scatter` and `scatter_add`
  56. // Tests:
  57. // 1. index.size(d) <= self.size(d) for all d != dim
  58. // 2. index.size(d) <= src.size(d) for all d if src is a Tensor
  59. // 3. index.dim() == self.dim() == src.dim()
  60. inline void scatter_shape_check(
  61. const Tensor& self, int64_t dim, const Tensor& index,
  62. const std::optional<Tensor>& src_opt = std::nullopt
  63. ) {
  64. if (index.numel() == 0) return;
  65. TORCH_CHECK(
  66. ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
  67. "Index tensor must have the same number of dimensions as self tensor"
  68. );
  69. bool is_wrong_shape = false;
  70. int64_t self_dims = ensure_nonempty_dim(self.dim());
  71. // Check: index.size(d) <= self.size(d) for all d != dim
  72. for (const auto d : c10::irange(self_dims)) {
  73. int64_t index_d_size = ensure_nonempty_size(index, d);
  74. if (d == dim) continue;
  75. if (index_d_size > ensure_nonempty_size(self, d)) {
  76. is_wrong_shape = true;
  77. break;
  78. }
  79. }
  80. // Check: index.size(d) <= src.size(d) for all d if src is Tensor
  81. if (!is_wrong_shape && src_opt.has_value()) {
  82. const auto& src = src_opt.value();
  83. for (const auto d : c10::irange(self_dims)) {
  84. int64_t index_d_size = ensure_nonempty_size(index, d);
  85. if (index_d_size > ensure_nonempty_size(src, d)) {
  86. is_wrong_shape = true;
  87. break;
  88. }
  89. }
  90. }
  91. if (src_opt.has_value()) {
  92. const auto& src = src_opt.value();
  93. TORCH_CHECK(
  94. ensure_nonempty_dim(src.dim()) == ensure_nonempty_dim(index.dim()),
  95. "Index tensor must have the same number of dimensions as src tensor"
  96. );
  97. TORCH_CHECK(!is_wrong_shape,
  98. "Expected index ", index.sizes(),
  99. " to be no larger than self ", self.sizes(),
  100. " apart from dimension ", dim,
  101. " and to be no larger size than src ", src.sizes()
  102. );
  103. }
  104. else {
  105. TORCH_CHECK(!is_wrong_shape,
  106. "Expected index ", index.sizes(),
  107. " to be no larger than self ", self.sizes(),
  108. " apart from dimension ", dim
  109. );
  110. }
  111. }
  112. } // anonymous namespace
  113. } // namespace at::native
  114. #else
  115. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  116. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)