TensorIndexing.h 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/ExpandUtils.h>
  4. #include <ATen/ScalarOps.h>
  5. #include <ATen/core/Tensor.h>
  6. #include <ATen/core/TensorBody.h>
  7. #include <c10/core/SymInt.h>
  8. #include <c10/util/irange.h>
  9. #include <optional>
  10. #ifndef AT_PER_OPERATOR_HEADERS
  11. #include <ATen/Functions.h>
  12. #include <ATen/NativeFunctions.h>
  13. #else
  14. #include <ATen/ops/alias.h>
  15. #include <ATen/ops/empty.h>
  16. #include <ATen/ops/scalar_tensor.h>
  17. #include <ATen/ops/zeros.h>
  18. #endif
  19. #include <ATen/core/List.h>
  20. #include <utility>
  21. namespace at::indexing {
  22. constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
  23. constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);
  24. enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };
  25. constexpr std::nullopt_t None = std::nullopt;
  26. struct TORCH_API EllipsisIndexType final {
  27. EllipsisIndexType() = default;
  28. };
  29. TORCH_API extern const EllipsisIndexType Ellipsis;
  30. struct TORCH_API Slice final {
  31. public:
  32. Slice(
  33. std::optional<c10::SymInt> start_index = std::nullopt,
  34. std::optional<c10::SymInt> stop_index = std::nullopt,
  35. std::optional<c10::SymInt> step_index = std::nullopt) {
  36. if (!step_index.has_value()) {
  37. step_ = c10::SymInt(1);
  38. } else {
  39. step_ = std::move(step_index).value();
  40. }
  41. TORCH_CHECK_VALUE(
  42. step_.sym_ne(0).expect_true(__FILE__, __LINE__),
  43. "slice step cannot be zero");
  44. if (!start_index.has_value()) {
  45. start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
  46. } else {
  47. start_ = std::move(start_index).value();
  48. }
  49. if (!stop_index.has_value()) {
  50. stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
  51. } else {
  52. stop_ = std::move(stop_index).value();
  53. }
  54. }
  55. inline c10::SymInt start() const {
  56. return start_;
  57. }
  58. inline c10::SymInt stop() const {
  59. return stop_;
  60. }
  61. inline c10::SymInt step() const {
  62. return step_;
  63. }
  64. private:
  65. c10::SymInt start_;
  66. c10::SymInt stop_;
  67. c10::SymInt step_;
  68. };
  69. TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);
  70. // `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
  71. // `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
  72. // into its equivalent `std::vector<TensorIndex>`, so that further tensor
  73. // indexing operations can be performed using the supplied indices.
  74. //
  75. // There is one-to-one correspondence between Python and C++ tensor index types:
  76. // Python | C++
  77. // -----------------------------------------------------
  78. // `None` | `at::indexing::None`
  79. // `Ellipsis` | `at::indexing::Ellipsis`
  80. // `...` | `"..."`
  81. // `123` | `123`
  82. // `True` / `False` | `true` / `false`
  83. // `:` | `Slice()` / `Slice(None, None)`
  84. // `::` | `Slice()` / `Slice(None, None, None)`
  85. // `1:` | `Slice(1, None)`
  86. // `1::` | `Slice(1, None, None)`
  87. // `:3` | `Slice(None, 3)`
  88. // `:3:` | `Slice(None, 3, None)`
  89. // `::2` | `Slice(None, None, 2)`
  90. // `1:3` | `Slice(1, 3)`
  91. // `1::2` | `Slice(1, None, 2)`
  92. // `:3:2` | `Slice(None, 3, 2)`
  93. // `1:3:2` | `Slice(1, 3, 2)`
  94. // `torch.tensor([1, 2])`) | `torch::tensor({1, 2})`
  95. struct TORCH_API TensorIndex final {
  96. // Case 1: `at::indexing::None`
  97. TensorIndex(std::nullopt_t /*unused*/) : type_(TensorIndexType::None) {}
  98. // Case 2: "..." / `at::indexing::Ellipsis`
  99. TensorIndex(at::indexing::EllipsisIndexType /*unused*/)
  100. : type_(TensorIndexType::Ellipsis) {}
  101. TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
  102. TORCH_CHECK_VALUE(
  103. strcmp(str, "...") == 0,
  104. "Expected \"...\" to represent an ellipsis index, but got \"",
  105. str,
  106. "\"");
  107. }
  108. // Case 3: (Sym) Integer value
  109. TensorIndex(SymInt integer)
  110. : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
  111. TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
  112. TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}
  113. // Case 4: Boolean value
  114. template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
  115. TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}
  116. // Case 5: Slice represented in `at::indexing::Slice` form
  117. TensorIndex(Slice slice)
  118. : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}
  119. // Case 6: Tensor value
  120. TensorIndex(Tensor tensor)
  121. : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}
  122. inline bool is_none() const {
  123. return type_ == TensorIndexType::None;
  124. }
  125. inline bool is_ellipsis() const {
  126. return type_ == TensorIndexType::Ellipsis;
  127. }
  128. inline bool is_integer() const {
  129. return type_ == TensorIndexType::SymInt;
  130. }
  131. inline SymInt integer() const {
  132. return integer_;
  133. }
  134. inline bool is_boolean() const {
  135. return type_ == TensorIndexType::Boolean;
  136. }
  137. inline bool boolean() const {
  138. return boolean_;
  139. }
  140. inline bool is_slice() const {
  141. return type_ == TensorIndexType::Slice;
  142. }
  143. inline const Slice& slice() const {
  144. return slice_;
  145. }
  146. inline bool is_tensor() const {
  147. return type_ == TensorIndexType::Tensor;
  148. }
  149. inline const Tensor& tensor() const {
  150. return tensor_;
  151. }
  152. private:
  153. SymInt integer_ = 0;
  154. bool boolean_ = false;
  155. Slice slice_;
  156. Tensor tensor_;
  157. TensorIndexType type_;
  158. };
  159. TORCH_API std::ostream& operator<<(
  160. std::ostream& stream,
  161. const TensorIndex& tensor_index);
  162. TORCH_API std::ostream& operator<<(
  163. std::ostream& stream,
  164. const std::vector<TensorIndex>& tensor_indices);
  165. namespace impl {
  166. inline Tensor applySlice(
  167. const Tensor& self,
  168. int64_t dim,
  169. c10::SymInt start,
  170. c10::SymInt stop,
  171. c10::SymInt step,
  172. bool disable_slice_optimization,
  173. const at::Device& self_device,
  174. const std::optional<SymIntArrayRef>& self_sizes) {
  175. // TODO: implement negative step
  176. TORCH_CHECK_VALUE(
  177. step.sym_gt(0).expect_true(__FILE__, __LINE__),
  178. "step must be greater than zero");
  179. // See NOTE [nested tensor size for indexing]
  180. if (self_sizes.has_value() && !self_sizes.value().empty()) {
  181. // Skip this optimization if we are tracing, as the trace may be polymorphic
  182. // over the shape of the `self` tensor, and we still want to record
  183. // the slice.
  184. SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
  185. ? (*self_sizes)[dim]
  186. : self.sym_size(dim);
  187. if (!disable_slice_optimization &&
  188. TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) &&
  189. TORCH_STATICALLY_KNOWN_TRUE(length.sym_le(stop)) && step == 1) {
  190. return self;
  191. }
  192. }
  193. return self.slice_symint(
  194. dim, std::move(start), std::move(stop), std::move(step));
  195. }
  196. inline Tensor applySelect(
  197. const Tensor& self,
  198. int64_t dim,
  199. SymInt index,
  200. int64_t real_dim,
  201. const at::Device& /*self_device*/,
  202. const std::optional<SymIntArrayRef>& self_sizes) {
  203. // See NOTE [nested tensor size for indexing]
  204. if (self_sizes.has_value()) {
  205. auto maybe_index = index.maybe_as_int();
  206. if (maybe_index.has_value()) {
  207. TORCH_CHECK_INDEX(
  208. !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
  209. "invalid index of a 0-dim tensor. ",
  210. "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
  211. }
  212. auto size = (*self_sizes)[dim];
  213. // Note: `size >= -index` is not equivalent to `size > -1 - index` if index
  214. // is INT64_MIN For std::numeric_limits<int64_t>::min() result of unary
  215. // minus is undefined by the standard but in practice is equal to self. On
  216. // the other hand, indexing wrapping is valid for all negative int64_t
  217. // values, as x[INT64_MIN] is the same as x[INT64_MAX]
  218. TORCH_CHECK_INDEX(
  219. size.sym_gt(-1 - index)
  220. .sym_and(size.sym_gt(index))
  221. .expect_true(__FILE__, __LINE__),
  222. "index ",
  223. index,
  224. " is out of bounds for dimension ",
  225. real_dim,
  226. " with size ",
  227. size);
  228. }
  229. // if the index is negative, do not normalize it because that would fix the
  230. // index on the current tensor size in the tracer. aten::select also works on
  231. // negative indices
  232. return self.select_symint(dim, std::move(index));
  233. }
  234. inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
  235. // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  236. // false as empty.
  237. if (value) {
  238. return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
  239. } else {
  240. return at::empty({0}, self.options().dtype(kLong));
  241. }
  242. }
  243. inline Tensor boolToIndexingTensorNonNativeDeviceType(
  244. const Tensor& self,
  245. bool value) {
  246. // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  247. // false as empty.
  248. if (value) {
  249. return at::zeros({1}, self.options().dtype(kLong));
  250. } else {
  251. return at::empty({0}, self.options().dtype(kLong));
  252. }
  253. }
  254. inline Tensor boolToIndexingTensor(
  255. const Tensor& self,
  256. bool value,
  257. const at::Device& self_device) {
  258. if (self_device == at::kCPU || self_device == at::kCUDA) {
  259. return boolToIndexingTensorCPUOrCUDA(self, value);
  260. } else {
  261. return boolToIndexingTensorNonNativeDeviceType(self, value);
  262. }
  263. }
  264. inline Tensor scalarToTensorNonNativeDeviceType(
  265. const Scalar& v,
  266. const TensorOptions& options) {
  267. return at::scalar_tensor(v, options);
  268. }
  269. inline void recordTensorIndex(
  270. const Tensor& tensor,
  271. std::vector<Tensor>& outIndices,
  272. int64_t* dim_ptr) {
  273. if (outIndices.empty()) {
  274. outIndices.resize(*dim_ptr + 1);
  275. outIndices[*dim_ptr] = tensor;
  276. } else {
  277. outIndices.push_back(tensor);
  278. }
  279. if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
  280. *dim_ptr += tensor.dim();
  281. } else {
  282. *dim_ptr += 1;
  283. }
  284. }
  285. inline c10::List<::std::optional<Tensor>> typeConvertIndices(
  286. const Tensor& /*self*/,
  287. std::vector<Tensor>&& indices) {
  288. c10::List<::std::optional<Tensor>> converted_inds;
  289. converted_inds.reserve(indices.size());
  290. for (auto&& i : std::move(indices)) {
  291. converted_inds.push_back(std::move(i));
  292. }
  293. return converted_inds;
  294. }
  295. // NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
  296. // function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
  297. // `count_specified_dimensions` is on the hot path of Python tensor multi-dim
  298. // indexing (i.e. it's called by `applySlicing` which is called by
  299. // `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
  300. // than one dimension). If we were to merge the Python/C++
  301. // `count_specified_dimensions` function, on the Python side we would have to
  302. // construct a `std::vector` container to be consumed by the C++
  303. // `count_specified_dimensions` function, which adds 100s of nanoseconds
  304. // overhead and is undesirable.
  305. inline int64_t count_specified_dimensions(
  306. const ArrayRef<TensorIndex>& indices) {
  307. // Count the number of indexed dimensions (everything but ellipsis and None)
  308. int64_t count = 0;
  309. for (auto& obj : indices) {
  310. if (obj.is_tensor()) {
  311. auto& tensor = obj.tensor();
  312. if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
  313. count += tensor.dim();
  314. } else {
  315. count++;
  316. }
  317. } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
  318. count++;
  319. }
  320. }
  321. return count;
  322. }
  323. } // namespace impl
  324. // NOTE: Many functions below are only for consumption from Python indexing
  325. // implementation, they include:
  326. //
  327. // - `Tensor scalarToTensor(...)`
  328. // - `IntArrayRef slicePrefix1sSize(...)`
  329. // - `void copy_to(...)`
  330. // - `Tensor handleDimInMultiDimIndexing(...)`
  331. // - `Tensor dispatch_index(...)`
  332. // - `Tensor dispatch_index_put_(...)`
  333. // - `Tensor get_item(...)`
  334. // - `void set_item(...)`
  335. //
  336. // The rest of the functions are in `at::indexing::impl` namespace, signifying
  337. // that they shouldn't be used from Python indexing implementation.
  338. inline Tensor scalarToTensor(
  339. const Scalar& v,
  340. const TensorOptions& options,
  341. const at::Device& self_device) {
  342. if (self_device == at::kCPU && !v.isSymbolic()) {
  343. return at::detail::scalar_tensor_static(
  344. v,
  345. // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  346. options.dtype_opt()->toScalarType(),
  347. self_device);
  348. } else {
  349. return impl::scalarToTensorNonNativeDeviceType(v, options);
  350. }
  351. }
  352. // To match numpy semantics:
  353. // As a special case for backwards compatibility,
  354. // strip away unit dimensions from the left of 'src'
  355. inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
  356. size_t first_non1_src = sizes.size();
  357. for (const auto i : c10::irange(sizes.size())) {
  358. // Unbacked SymInt has different behavior, but this is sound because
  359. // failing to slice will only ever cause an error, not divergent
  360. // behavior
  361. if (!sizes[i].has_hint() || sizes[i] != 1) {
  362. first_non1_src = i;
  363. break;
  364. }
  365. }
  366. return sizes.slice(first_non1_src);
  367. }
  368. inline void copy_to(const Tensor& dst, const Tensor& src) {
  369. if (dst.sym_sizes().equals(src.sym_sizes())) {
  370. // A shortcut to avoid generating hard-coded constant sizes during tracing.
  371. // This is not a perfect solution: when src & dst have different shapes,
  372. // constants will still appear. Users can workaround that case by
  373. // dst[index..] = src.reshape(..)
  374. dst.copy_(src);
  375. return;
  376. } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
  377. dst.fill_(src);
  378. return;
  379. }
  380. auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
  381. c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
  382. dst.copy_(*b_src);
  383. }
  384. // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
  385. // indexing functions from Python ]
  386. inline Tensor handleDimInMultiDimIndexing(
  387. const Tensor& prev_dim_result,
  388. const Tensor& original_tensor,
  389. const TensorIndex& index,
  390. int64_t* dim_ptr,
  391. int64_t* specified_dims_ptr,
  392. int64_t real_dim,
  393. std::vector<Tensor>& outIndices,
  394. bool disable_slice_optimization,
  395. const at::Device& original_tensor_device,
  396. const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
  397. if (index.is_integer()) {
  398. return impl::applySelect(
  399. prev_dim_result,
  400. *dim_ptr,
  401. index.integer(),
  402. real_dim,
  403. original_tensor_device,
  404. prev_dim_result_sizes);
  405. } else if (index.is_slice()) {
  406. Tensor result = impl::applySlice(
  407. prev_dim_result,
  408. *dim_ptr,
  409. index.slice().start(),
  410. index.slice().stop(),
  411. index.slice().step(),
  412. /*disable_slice_optimization=*/disable_slice_optimization,
  413. original_tensor_device,
  414. prev_dim_result_sizes);
  415. (*dim_ptr)++;
  416. if (!outIndices.empty()) {
  417. outIndices.resize(outIndices.size() + 1);
  418. }
  419. return result;
  420. } else if (index.is_ellipsis()) {
  421. auto ellipsis_ndims = original_tensor.dim() - *specified_dims_ptr;
  422. (*dim_ptr) += ellipsis_ndims;
  423. if (!outIndices.empty()) {
  424. outIndices.resize(outIndices.size() + ellipsis_ndims);
  425. }
  426. return prev_dim_result;
  427. } else if (index.is_none()) {
  428. Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
  429. (*dim_ptr)++;
  430. if (!outIndices.empty()) {
  431. outIndices.resize(outIndices.size() + 1);
  432. }
  433. return result;
  434. } else if (index.is_boolean()) {
  435. Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
  436. impl::recordTensorIndex(
  437. impl::boolToIndexingTensor(
  438. result, index.boolean(), original_tensor_device),
  439. outIndices,
  440. dim_ptr);
  441. return result;
  442. } else if (index.is_tensor()) {
  443. Tensor result = prev_dim_result;
  444. const Tensor& tensor = index.tensor();
  445. auto scalar_type = tensor.scalar_type();
  446. bool is_batched = tensor.key_set().has_any(c10::DispatchKeySet(
  447. {c10::DispatchKey::FuncTorchBatched,
  448. c10::DispatchKey::BatchedNestedTensor}));
  449. if (tensor.dim() == 0 && !is_batched &&
  450. at::isIntegralType(scalar_type, /*includeBool=*/true)) {
  451. if (scalar_type != at::kByte && scalar_type != at::kBool) {
  452. result = impl::applySelect(
  453. result,
  454. *dim_ptr,
  455. tensor.item<int64_t>(),
  456. real_dim,
  457. original_tensor_device,
  458. prev_dim_result_sizes);
  459. } else {
  460. result = result.unsqueeze(*dim_ptr);
  461. if (scalar_type == at::kBool) {
  462. impl::recordTensorIndex(
  463. impl::boolToIndexingTensor(
  464. result, tensor.item<bool>() != 0, original_tensor_device),
  465. outIndices,
  466. dim_ptr);
  467. } else {
  468. impl::recordTensorIndex(
  469. impl::boolToIndexingTensor(
  470. result, tensor.item<uint8_t>() != 0, original_tensor_device),
  471. outIndices,
  472. dim_ptr);
  473. }
  474. }
  475. } else {
  476. impl::recordTensorIndex(tensor, outIndices, dim_ptr);
  477. }
  478. return result;
  479. } else {
  480. TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
  481. }
  482. }
  483. namespace impl {
  484. // This mirrors `applySlicing` in
  485. // torch/csrc/autograd/python_variable_indexing.cpp
  486. inline Tensor applySlicing(
  487. const Tensor& self,
  488. const ArrayRef<TensorIndex>& indices,
  489. std::vector<Tensor>& outIndices,
  490. bool disable_slice_optimization,
  491. const at::Device& self_device,
  492. const std::optional<SymIntArrayRef>& self_sizes) {
  493. int64_t dim = 0;
  494. int64_t specified_dims = impl::count_specified_dimensions(indices);
  495. // See NOTE [nested tensor size for indexing]
  496. if (self_sizes.has_value()) {
  497. TORCH_CHECK_INDEX(
  498. specified_dims <= (int64_t)self_sizes->size(),
  499. "too many indices for tensor of dimension ",
  500. (int)self_sizes->size());
  501. }
  502. Tensor result = self;
  503. for (const auto i : c10::irange(indices.size())) {
  504. auto& obj = indices[i];
  505. // See NOTE [nested tensor size for indexing]
  506. std::optional<SymIntArrayRef> result_sizes = result.is_nested()
  507. ? std::optional<SymIntArrayRef>(std::nullopt)
  508. : std::optional<SymIntArrayRef>(result.sym_sizes());
  509. result = handleDimInMultiDimIndexing(
  510. /*prev_dim_result=*/result,
  511. /*original_tensor=*/self,
  512. /*index=*/obj,
  513. /*dim_ptr=*/&dim,
  514. /*specified_dims_ptr=*/&specified_dims,
  515. /*real_dim=*/static_cast<int64_t>(i),
  516. /*outIndices=*/outIndices,
  517. /*disable_slice_optimization=*/disable_slice_optimization,
  518. /*original_tensor_device=*/self_device,
  519. /*prev_dim_result_sizes=*/result_sizes);
  520. }
  521. return result;
  522. }
  523. } // namespace impl
  524. inline Tensor dispatch_index(
  525. const Tensor& self,
  526. std::vector<Tensor>&& indices) {
  527. // Remove trailing null elements from indices
  528. while (!indices.empty() && !indices.back().defined()) {
  529. indices.pop_back();
  530. }
  531. return self.index(impl::typeConvertIndices(self, std::move(indices)));
  532. }
  533. inline Tensor dispatch_index_put_(
  534. Tensor& self,
  535. std::vector<Tensor>&& indices,
  536. const Tensor& value) {
  537. // Remove trailing null elements from indices
  538. while (!indices.empty() && !indices.back().defined()) {
  539. indices.pop_back();
  540. }
  541. return self.index_put_(
  542. impl::typeConvertIndices(self, std::move(indices)), value);
  543. }
  544. // NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
  545. // functions from Python ]
  546. //
  547. // Question: When should we set `disable_slice_optimization` to `true` when
  548. // calling C++ tensor indexing functions from Python indexing code?
  549. //
  550. // Answer: What "slice optimization" means: when we have a slicing expression
  551. // like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
  552. // would skip dispatching the actual slice call as an optimization. However,
  553. // here are the cases where we DON'T want this optimization:
  554. //
  555. // 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
  556. // Reason: we always return a shallow copy for expressions such as
  557. // `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:,
  558. // :]`, we return an alias of `tensor` by doing the following:
  559. // ```
  560. // Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
  561. // disable_slice_optimization, self_device, self_sizes); if
  562. // (tensorIndices.empty()) {
  563. // if (sliced.is_same(self)) {
  564. // // ensure we return a shallow copy for things like x[...]
  565. // sliced = at::alias(sliced);
  566. // }
  567. // return sliced;
  568. // }
  569. // ```)
  570. // 2. When we are doing JIT tracing.
  571. // Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
  572. // slice operation.
  573. // This mirrors `THPVariable_getitem` in
  574. // torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting
  575. // `disable_slice_optimization` when calling C++ tensor indexing functions from
  576. // Python ]
  577. inline Tensor get_item(
  578. const Tensor& self,
  579. const ArrayRef<TensorIndex>& indices,
  580. bool disable_slice_optimization = false) {
  581. at::Device self_device = self.device();
  582. // NOTE [nested tensor size for indexing]
  583. // nested tensor does not have a size (yet) so for now we represent its size
  584. // as null may need to be changed after we reach a better solution for nested
  585. // tensor size
  586. std::optional<SymIntArrayRef> self_sizes = self.is_nested()
  587. ? std::optional<SymIntArrayRef>(std::nullopt)
  588. : std::optional<SymIntArrayRef>(self.sym_sizes());
  589. // handle simple types: integers, slices, none, ellipsis, bool
  590. if (indices.size() == 1) {
  591. const TensorIndex& index = indices[0];
  592. if (index.is_integer()) {
  593. return impl::applySelect(
  594. self, 0, index.integer(), 0, self_device, self_sizes);
  595. } else if (index.is_slice()) {
  596. return impl::applySlice(
  597. self,
  598. 0,
  599. index.slice().start(),
  600. index.slice().stop(),
  601. index.slice().step(),
  602. /*disable_slice_optimization=*/true,
  603. self_device,
  604. self_sizes);
  605. } else if (index.is_none()) {
  606. return self.unsqueeze(0);
  607. } else if (index.is_ellipsis()) {
  608. return at::alias(self);
  609. } else if (index.is_boolean()) {
  610. Tensor result = self.unsqueeze(0);
  611. return dispatch_index(
  612. result,
  613. std::vector<Tensor>{impl::boolToIndexingTensor(
  614. result, index.boolean(), self_device)});
  615. }
  616. }
  617. std::vector<Tensor> tensorIndices;
  618. Tensor sliced = impl::applySlicing(
  619. self,
  620. indices,
  621. tensorIndices,
  622. disable_slice_optimization,
  623. self_device,
  624. self_sizes);
  625. if (tensorIndices.empty()) {
  626. if (sliced.is_same(self)) {
  627. // ensure we return a shallow copy for things like x[...]
  628. sliced = at::alias(sliced);
  629. }
  630. return sliced;
  631. }
  632. // indexing by tensors ("advanced" indexing)
  633. return dispatch_index(sliced, std::move(tensorIndices));
  634. }
  635. // This mirrors `THPVariable_setitem` in
  636. // torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a
  637. // Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++
  638. // tensor indexing functions from Python ]
  639. inline void set_item(
  640. const Tensor& self,
  641. const ArrayRef<TensorIndex>& indices,
  642. const Tensor& value,
  643. bool disable_slice_optimization = false) {
  644. at::Device self_device = self.device();
  645. SymIntArrayRef self_sizes = self.sym_sizes();
  646. // handle simple types: integers, slices, ellipsis, bool
  647. if (indices.size() == 1) {
  648. const TensorIndex& index = indices[0];
  649. if (index.is_boolean() && !index.boolean()) {
  650. // do nothing for false (technically we should check the size, but we
  651. // don't have real 0-sized shapes.
  652. return;
  653. } else if (index.is_ellipsis()) {
  654. copy_to(self, value);
  655. return;
  656. } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
  657. copy_to(self.unsqueeze(0), value);
  658. return;
  659. } else if (index.is_integer()) {
  660. copy_to(
  661. impl::applySelect(
  662. self, 0, index.integer(), 0, self_device, self_sizes),
  663. value);
  664. return;
  665. } else if (index.is_slice()) {
  666. copy_to(
  667. impl::applySlice(
  668. self,
  669. 0,
  670. index.slice().start(),
  671. index.slice().stop(),
  672. index.slice().step(),
  673. /*disable_slice_optimization=*/disable_slice_optimization,
  674. self_device,
  675. self_sizes),
  676. value);
  677. return;
  678. }
  679. }
  680. std::vector<Tensor> tensorIndices;
  681. Tensor sliced = impl::applySlicing(
  682. self,
  683. indices,
  684. tensorIndices,
  685. disable_slice_optimization,
  686. self_device,
  687. self_sizes);
  688. if (tensorIndices.empty()) {
  689. copy_to(sliced, value);
  690. return;
  691. }
  692. SymIntArrayRef valueSizes = value.sym_sizes();
  693. SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
  694. Tensor valuesSliced;
  695. if (!valueSizes.equals(slicedValueSizes)) {
  696. valuesSliced = value.view_symint(slicedValueSizes);
  697. } else {
  698. valuesSliced = value;
  699. }
  700. dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
  701. return;
  702. }
  703. } // namespace at::indexing
  704. #else
  705. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  706. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)