TensorMeta.h 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/DimVector.h>
  4. #include <ATen/core/Dimname.h>
  5. #include <c10/core/TensorOptions.h>
  6. #include <c10/util/strides.h>
  7. namespace at {
  8. class Tensor;
  9. namespace impl {
  10. // Use this to define the prototype for a meta function. There are two
  11. // versions; one that takes one argument (just the operator name), or FUNC2
  12. // variant that takes two arguments (operator name and overload name).
  13. //
  14. // Example usage:
  15. //
  16. // TORCH_META_FUNC2(add, Tensor) (
  17. // const Tensor& self, const Tensor& other
  18. // ) {
  19. // ... compute sizes and options ...
  20. // set_output(sizes, options);
  21. // }
  22. //
  23. #define TORCH_META_FUNC(name) void structured_##name::meta
  24. #define TORCH_META_FUNC2(name, overload) \
  25. void structured_##name##_##overload::meta
  26. // These are versions of TORCH_META_FUNC(2) that include a precompute_out struct
  27. // as a return value. They should be used when the kernel in question has
  28. // precomputed values declared in native_functions.yaml and the corresponding
  29. // implementation should return an instance of the aforementioned struct.
  30. #define TORCH_PRECOMPUTE_META_FUNC(name) \
  31. structured_##name::meta_return_ty structured_##name::meta
  32. #define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
  33. structured_##name##_##overload::meta_return_ty \
  34. structured_##name##_##overload::meta
  35. // Use this to create a precompute struct in a meta function.
  36. #define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
  37. #define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
  38. structured_##name##_##overload::precompute_out<>
  39. // Use this to define the prototype for an implementation. This takes only
  40. // one argument, which is the name of the dispatch key entry you're
  41. // implementing.
  42. //
  43. // Example usage:
  44. //
  45. // TORCH_IMPL_FUNC(add_cpu) (
  46. // Tensor& result, const Tensor& self, const Tensor& other
  47. // ) {
  48. // ... do the actual implementation ...
  49. // }
  50. //
  51. #define TORCH_IMPL_FUNC(name) void structured_##name::impl
  52. // Base class for all structured kernel classes. The set_output virtual
  53. // method is varied depending whether or not the operator is
  54. // functional/out/inplace, and could also be specialized for CPU/CUDA/etc
  55. // (although presently it isn't).
  56. //
  57. // A notable subclass of this interface is TensorIteratorBase.
  58. struct TORCH_API MetaBase {
  59. MetaBase() = default;
  60. MetaBase(const MetaBase&) = default;
  61. MetaBase& operator=(const MetaBase&) = default;
  62. MetaBase(MetaBase&&) noexcept = default;
  63. MetaBase& operator=(MetaBase&&) noexcept = default;
  64. virtual const Tensor& maybe_get_output(int64_t output_idx) = 0;
  65. // Note: [set_output_*]
  66. // See: https://github.com/pytorch/pytorch/issues/69813
  67. // Whenever defining the output properties in the META function of a
  68. // structured kernel (what was usually done with `set_output`), use one of
  69. // these 3 variants, instead. In order to decide which variant to use, check
  70. // the following decision tree:
  71. //
  72. // - Can the kernel you are going to implement support output tensors
  73. // with arbitrary strides?
  74. // |
  75. // -- YES: `set_output_raw_strided`
  76. // |
  77. // -- NO: Should the output tensor strides be contiguous?
  78. // |
  79. // -- YES: `set_output_contiguous`
  80. // |
  81. // -- NO: `set_output_strided`
  82. //
  83. // Use this function whenever the kernel requires specific strides for the
  84. // output. If `strides` does not match the given output strides, proxy outputs
  85. // will be created and passed to the IMPL function.
  86. virtual void set_output_strided(
  87. int64_t output_idx [[maybe_unused]],
  88. IntArrayRef sizes [[maybe_unused]],
  89. IntArrayRef strides [[maybe_unused]],
  90. TensorOptions options [[maybe_unused]],
  91. DimnameList names [[maybe_unused]] = {}) {
  92. TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
  93. }
  94. // Use this function whenever the kernel knows how to handle arbitrary strided
  95. // outputs. This function has the same behavior as the old `set_output`: it
  96. // will only re-stride if the given output was resized.
  97. virtual void set_output_raw_strided(
  98. int64_t output_idx [[maybe_unused]],
  99. IntArrayRef sizes [[maybe_unused]],
  100. IntArrayRef strides_hint [[maybe_unused]],
  101. TensorOptions options [[maybe_unused]],
  102. DimnameList names [[maybe_unused]] = {}) {
  103. TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented.");
  104. }
  105. // Use this function if the kernel requires contiguous strides.
  106. // Alias for `set_output_strided`, but with contiguous strides.
  107. void set_output_contiguous(
  108. int64_t output_idx,
  109. IntArrayRef sizes,
  110. TensorOptions options,
  111. DimnameList names = {}) {
  112. auto strides = c10::contiguous_strides(sizes);
  113. set_output_strided(output_idx, sizes, strides, options, names);
  114. }
  115. // Returns a reference to an undefined tensor if there is no presupplied
  116. // output
  117. const Tensor& maybe_get_output() {
  118. return maybe_get_output(0);
  119. }
  120. virtual ~MetaBase() = default;
  121. };
  122. } // namespace impl
  123. } // namespace at
  124. #else
  125. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  126. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)