LegacyVmapTransforms.h 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/LegacyBatchedTensorImpl.h>
  4. #include <ATen/core/IListRef.h>
  5. namespace at {
  // This file contains abstractions used for transforming *logical* vmap
  // arguments into *physical* arguments. (Keep reading for definitions of these
  // terms.)
  //
  // NOTE: [Logical vs physical args]
  // Consider the following vmap:
  //   vmap(vmap(func, in_dims=(1,)), in_dims=(0,))(torch.ones(2, 3, 4))
  // The outer vmap batches over dim 0, so the inner vmap sees a tensor of size
  // [3, 4]; its in_dim 1 is therefore dim 2 of the original tensor. This
  // produces a BatchedTensor wrapping a Tensor of size [2, 3, 4], with batch
  // dims 0 and 2:
  //   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
  //
  // We say the *logical* view of the tensor has size [3] -- tensors inside
  // `func` appear to have size [3].
  // However, the *physical* underlying tensor (the one passed to vmap) has size
  // [2, 3, 4].
  //
  // This notion of logical vs physical also extends to non-tensor arguments.
  // Consider the previous tensor, and assume the user called
  // `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
  // dimension they are reducing over is dim 0, but the physical dim is dim 1
  // (the first non-batch dimension).
  26. // Forward declared; see NOTE: [What is a VmapPhysicalView?]
  27. struct VmapPhysicalView;
  28. // Most PyTorch operators take 4 or fewer inputs.
  29. constexpr int64_t kVmapTransformStaticInputSize = 4;
  30. using VmapPhysicalViewVec =
  31. SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
  32. // Pytorch generally advertises good performance for <= 5 dims.
  33. // (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
  34. // dimensions to get 8. Adjust this number as necessary
  35. constexpr int64_t kVmapStaticDimVecSize = 8;
  36. using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
  37. using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
  // NOTE: [What is an VmapTransform?]
  // A *VmapTransform* converts logical views of tensors to physical views.
  //
  // Batching rules use VmapTransforms to convert logical arguments to
  // physical arguments, then call one or more at:: operators that handle the
  // physical arguments, and then convert the physical result back to a logical
  // argument.
  45. // VmapTransform for operators that take tensors with multiple batch dims.
  46. // Given one or more logical views on Tensors, `logicalToPhysical`
  47. // permutes all of the batch dims to the front of the tensor, aligns
  48. // and expands the batch dims to match each other (according to their `level`),
  49. // and returns a VmapPhysicalView on the tensor(s).
  50. struct TORCH_API MultiBatchVmapTransform {
  51. static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  52. static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
  53. };
  54. // VmapTransform for operators that broadcast all inputs.
  55. // Given some logical views on Tensors, `logicalToPhysical`:
  56. // - permutes all of the batch dims to the front of the tensors
  57. // - aligns all the batch dims to the collective levels of all of the tensors.
  58. // If a tensor does not have a batch dim for a vmap level, then it receives
  59. // a size-one dimension for said level.
  60. // - aligns the non-batch dims to have the same dimensionality, adding extra
  61. // size-1 dimensions in between the batch dimensions and the non-batch
  62. // dimensions so that the batch dimensions are lined up from the right.
  63. //
  64. // For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
  65. // dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
  66. // tensors of size (B, 1, 2) and (B, 3, 2).
  67. //
  68. // Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
  69. // VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
  70. // actually *need* to return a tensor of size (1, 2) for the second tensor
  71. // because the broadcasting operation takes care of that for us, but we do
  72. // it anyways to keep things simple.
  73. struct TORCH_API BroadcastingVmapTransform {
  74. static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
  75. };
  76. // Forward declared, if you're reading this file head to toe, don't worry about
  77. // it yet.
  78. struct VmapPhysicalToLogicalMap;
  79. // NOTE: [What is a VmapPhysicalView?]
  80. // VmapPhysicalView represents a physical view on a Tensor.
  81. //
  82. // One can use it to further convert logical dimension indices, logical shapes,
  83. // and more to their physical variants, or convert a new (physical) tensor into
  84. // a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
  85. //
  86. // VmapPhysicalView stores a physical tensor with all of its batch dimensions at
  87. // the front and some levels that correspond to said batch dimensions.
  88. //
  89. // The levels bitset specifies which vmap levels correspond to the batch
  90. // dimensions at the front of the tensor. In particular, the number of set bits
  91. // corresponds to the number of batch dimensions on `tensor` and the rightmost
  92. // bit of `levels` specifies the maximum number of nested vmaps we are in at
  93. // this point in time.
  94. // For example, given:
  95. // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
  96. //
  97. // Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
  98. // than or equal to 3.
  99. // bitset: 010100
  100. // ^
  101. // |
  102. // levels: 012345
  103. struct TORCH_API VmapPhysicalView {
  104. VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
  105. : levels_(levels), tensor_(std::move(tensor)) {
  106. TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
  107. }
  108. Tensor& tensor() {
  109. return tensor_;
  110. }
  111. const Tensor& tensor() const {
  112. return tensor_;
  113. }
  114. // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  115. //
  116. // For example, given:
  117. // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  118. //
  119. // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  120. // This is because the size of levels tell us that the first two dimensions
  121. // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
  122. // a physical dim of `n + 2`.
  123. VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
  124. int64_t getPhysicalDim(int64_t logical_dim) const;
  125. // Returns a VmapPhysicalToLogicalMap object. This can be used for
  126. // mapping a physical tensor to a new logical tensor (BatchedTensor)
  127. VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
  128. // Maps a logical shape to a physical shape by prepending the batch
  129. // sizes to the logical shape.
  130. VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
  131. int64_t numBatchDims() const;
  132. private:
  133. int64_t numLogicalDims() const;
  134. std::bitset<kVmapNumLevels> levels_;
  135. Tensor tensor_;
  136. };
  137. // Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
  138. // to a logical one (BatchedTensor). It holds some levels that are used to do
  139. // the mapping and assumes that the batch dimensions in the physical tensor all
  140. // occur at the front of the tensor.
  141. struct TORCH_API VmapPhysicalToLogicalMap {
  142. VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
  143. : levels_(levels) {}
  144. // Maps a physical tensor to a new logical tensor (BatchedTensor).
  145. // Assumes that all of the "batch dimensions" are at the front
  146. // of the physical tensor. For example, given:
  147. // - x = rank-4 Tensor with size 2, 3, 5, 7
  148. // - levels = (2, 4)
  149. // Returns:
  150. // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
  151. Tensor apply(const Tensor& physical_tensor) const;
  152. // Given a vector of physical tensors,
  153. // 1. maps each tensor to a new logical tensor. Assumes that all of the
  154. // "batch dimensions" are at the front of the physical tensors.
  155. // 2. stores the new logical tensors back into the passed-in vector. This is
  156. // to avoid additional dynamic allocations.
  157. void applyInplace(std::vector<Tensor>& physical_tensors) const;
  158. std::bitset<kVmapNumLevels> levels_;
  159. };
  160. } // namespace at
  161. #else
  162. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  163. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)