LegacyBatchedTensorImpl.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <bitset>
  4. #include <ATen/ArrayRef.h>
  5. #include <ATen/SmallVector.h>
  6. #include <ATen/Tensor.h>
  7. namespace at {
  // Upper bound on the number of dimensions a tensor may have when it takes
  // part in vmap. Other places in the codebase assume this same value, but
  // there is no single centralized definition of it.
  constexpr int64_t kVmapMaxTensorDims{64};

  // Vmap levels are identifiers drawn from the half-open range [0, 64), so at
  // most 64 vmaps can be nested inside one another.
  constexpr int64_t kVmapNumLevels{64};

  // How many BatchDims are stored on the stack (inline) before spilling. Most
  // users are expected to nest at most ~5 vmaps; adjust if that changes.
  constexpr int64_t kBatchDimsStackSize{5};
  // BatchDim identifies one "private" dimension of a Tensor that was created
  // inside vmap. It is a (level, dim) pair:
  //   - `level` names the vmap invocation inside which the dimension was
  //     created;
  //   - `dim` is a "physical dim", i.e. the index of the vmapped-over
  //     dimension on the underlying physical tensor.
  // Note that the constructor takes its arguments in (level, dim) order.
  struct BatchDim {
    BatchDim(int64_t level, int64_t dim) : dim_{dim}, level_{level} {}

    // Physical dimension index on the underlying tensor.
    int64_t dim() const { return dim_; }

    // Identifier of the vmap this dimension belongs to.
    int64_t level() const { return level_; }

   private:
    int64_t dim_;
    int64_t level_;
  };
  35. using BatchDims = SmallVector<BatchDim, kBatchDimsStackSize>;
  36. using BatchDimsRef = ArrayRef<BatchDim>;
  37. // A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim
  38. // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
  39. // BatchedTensorImpl.
  40. //
  41. // The batch dimensions are treated as being "private"; they are not
  42. // user-visible. For example, in the following Tensor,
  43. // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)])
  44. // dimensions 0 and 1 are batch dimensions.
  45. //
  46. // bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
  47. // dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7)
  48. // tensor.
  49. struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
  50. explicit BatchedTensorImpl(Tensor value, BatchDims bdims);
  51. // Returns a reference to BatchDims that represent which dimensions of this
  52. // tensor are private.
  53. BatchDimsRef bdims() const {
  54. return bdims_;
  55. }
  56. // BatchedTensorImpl wraps a Tensor
  57. const Tensor& value() const {
  58. return value_;
  59. }
  60. // Given a public dimension index, return the dimension index in the
  61. // underlying value() tensor. For example, if we have
  62. // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2,
  63. // dim=2)])
  64. // bt.actualDim(0) -> 1
  65. // bt.actualDim(1) -> 3
  66. // bt.actualDim(2) -> Error
  67. int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
  68. // We have to override this because we opted into CustomStrides
  69. IntArrayRef strides_custom() const override;
  70. // Override a bunch of methods inherited from TensorImpl to return error
  71. // messages.
  72. c10::SymBool sym_is_contiguous_custom(
  73. at::MemoryFormat memory_format) const override;
  74. void set_size(int64_t dim, int64_t new_size) override;
  75. void set_stride(int64_t dim, int64_t new_stride) override;
  76. void set_storage_offset(int64_t storage_offset) override;
  77. #ifdef DEBUG
  78. bool has_storage() const override;
  79. #endif
  80. private:
  81. // see NOTE: [BatchedTensorImpl levels invariant]
  82. void checkInvariants() const;
  83. const char* tensorimpl_type_name() const override;
  84. Tensor value_;
  85. // Note: [BatchedTensorImpl levels invariant]
  86. // There is an invariant that the BatchDims must be stored in increasing
  87. // `level` order. That is, for i < j, bdims_[i].level must be less than
  88. // bdims_[j].level.
  89. BatchDims bdims_;
  90. };
  91. // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
  92. // BatchedTensorImpl.
  93. inline bool isBatchedTensor(const Tensor& tensor) {
  94. return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched);
  95. }
  96. // It is unsafe to call this on a Tensor that is not backed by a
  97. // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
  98. inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  99. return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
  100. }
  101. inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  102. if (!isBatchedTensor(tensor)) {
  103. return nullptr;
  104. }
  105. return unsafeGetBatchedImpl(tensor);
  106. }
  107. // Returns a bitset. If bit i is set, then that means dim i is a batchdim.
  108. inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(
  109. BatchDimsRef bdims) {
  110. std::bitset<kVmapMaxTensorDims> is_bdim;
  111. for (const auto& bdim : bdims) {
  112. is_bdim.set(bdim.dim());
  113. }
  114. return is_bdim;
  115. }
  116. // Creates a bitset for all of the levels present in `bdims`
  117. inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
  118. std::bitset<kVmapNumLevels> result;
  119. for (const auto& bdim : bdims) {
  120. result.set(bdim.level());
  121. }
  122. return result;
  123. }
  124. inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
  125. out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
  126. return out;
  127. }
  128. // Use this to construct a BatchedTensor from a regular Tensor
  129. TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims);
  130. // Adds a batch dim to `tensor`, returning a BatchedTensor
  131. TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim);
  132. // Checks if an inplace operation on self and other is "vmap compatible".
  133. // See NOTE: [vmap-incompatible in-place operations] for the definition of this.
  134. TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other);
  135. } // namespace at
  136. #else
  137. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  138. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)