// RegisterFunctionalization.cpp (3.4 KB)
  1. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
  2. // ${generated_comment}
  3. #include <ATen/core/LegacyTypeDispatch.h>
  4. #include <ATen/EmptyTensor.h>
  5. #include <ATen/FunctionalTensorWrapper.h>
  6. #include <ATen/ViewMetaClasses.h>
  7. #include <ATen/MemoryOverlap.h>
  8. #include <torch/library.h>
  9. #include <c10/util/env.h>
  10. #ifndef AT_PER_OPERATOR_HEADERS
  11. #include <ATen/Operators.h>
  12. #include <ATen/NativeFunctions.h>
  13. #else
  14. // needed for the meta tensor calls to get stride info in functionalization
  15. #include <ATen/ops/empty_strided_native.h>
  16. // needed for special handling of copy_().
  17. // See Note [functionalizating copy_() and not preserving strides]
  18. #include <ATen/ops/to_ops.h>
  19. #include <ATen/ops/expand_copy_ops.h>
  20. $ops_headers
  21. #endif
  22. namespace at {
  23. namespace functionalization {
  24. // This keyset is used by functionalization when it calls into meta kernels
  25. // to accurately propagate stride metadata.
  26. // Exclude any modes: the purpose of calling into meta kernels is only as an implementation
  27. // detail to perform shape inference, and we don't want any modal keys to run.
  28. // Specifically, we want to prevent functionalization and Python modes from running.
  29. constexpr auto exclude_keys_for_meta_dispatch =
  30. c10::functorch_transforms_ks |
  31. c10::DispatchKeySet({
  32. c10::DispatchKey::FuncTorchDynamicLayerBackMode,
  33. c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
  34. c10::DispatchKey::Python,
  35. c10::DispatchKey::PreDispatch,
  36. });
  37. // Helper around at::has_internal_overlap.
  38. // The ATen util is used in hot-path eager mode: it's always fast,
  39. // but might return TOO_HARD sometimes.
  40. // During functionalization, we're ok taking a bit longer
  41. // to detect memory overlap.
  42. inline bool has_internal_overlap_helper(const at::Tensor t) {
  43. auto has_overlap = at::has_internal_overlap(t);
  44. if (has_overlap == at::MemOverlap::Yes) return true;
  45. if (has_overlap == at::MemOverlap::No) return false;
  46. return false;
  47. }
  48. inline Tensor to_meta(const Tensor& t) {
  49. if (!t.defined()) return t;
  50. return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(),
  51. /*dtype=*/t.scalar_type(), /*layout=*/t.layout(),
  52. /*device=*/c10::Device(kMeta), /*pin_memory=*/std::nullopt);
  53. }
  54. inline std::optional<Tensor> to_meta(const std::optional<Tensor>& t) {
  55. if (t.has_value()) {
  56. return to_meta(*t);
  57. }
  58. return std::nullopt;
  59. }
  60. inline std::vector<Tensor> to_meta(at::ITensorListRef t_list) {
  61. std::vector<Tensor> outputs;
  62. outputs.reserve(t_list.size());
  63. for (const auto& tensor : t_list) {
  64. outputs.push_back(to_meta(tensor));
  65. }
  66. return outputs;
  67. }
  68. inline c10::List<Tensor> to_meta(const c10::List<Tensor>& t_list) {
  69. c10::List<Tensor> outputs;
  70. outputs.reserve(t_list.size());
  71. for (const auto i : c10::irange(t_list.size())) {
  72. outputs.push_back(to_meta(t_list[i]));
  73. }
  74. return outputs;
  75. }
  76. inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optional<Tensor>>& t_list) {
  77. c10::List<::std::optional<Tensor>> outputs;
  78. outputs.reserve(t_list.size());
  79. for (const auto i : c10::irange(t_list.size())) {
  80. outputs.push_back(to_meta(t_list[i]));
  81. }
  82. return outputs;
  83. }
  84. static bool disable_meta_reference() {
  85. static auto env = c10::utils::get_env("TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE");
  86. return env == "1";
  87. }
  88. ${func_definitions}
  89. } // namespace functionalization
  // Registers this file's generated functionalization kernels with the
  // dispatcher: TORCH_LIBRARY_IMPL(aten, Functionalize, m) attaches the
  // registrations (filled in by the code generator at the placeholder below)
  // to operators in the `aten` namespace for the Functionalize dispatch key.
  // The anonymous namespace keeps the registration object local to this
  // translation unit.
  // NOTE(review): the trailing ';' after the placeholder presumably terminates
  // the last generated registration statement; verify against the generator.
  90. namespace {
  91. TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
  92. ${func_registrations};
  93. }
  94. } // namespace
  95. } // namespace at