FunctionalizeFallbackKernel.h 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/FunctionalStorageImpl.h>
  4. namespace at::functionalization {
  5. // `ViewMeta` implementation for `resize_` operation.
  6. struct TORCH_API resize__ViewMeta : public ViewMeta {
  7. FUNCTIONALIZATION_VIEWMETA_NAME(resize__ViewMeta)
  8. FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
  9. bool /* reapply_views */,
  10. const std::vector<int64_t>&);
  11. resize__ViewMeta(const SerializableTuple& tpl)
  12. : resize__ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
  13. resize__ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
  14. : ViewMeta(/*has_symbolic_inputs=*/false),
  15. reapply_views(reapply_views),
  16. size(size) {}
  17. Tensor forward(const Tensor& base) override;
  18. Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
  19. SerializableTuple to_serializable_tuple() {
  20. return std::make_tuple(reapply_views, size);
  21. }
  22. bool reapply_views;
  23. std::vector<int64_t> size;
  24. };
  25. // `ViewMeta` implementation for `_unsafe_view` operation.
  26. struct TORCH_API _unsafe_view_ViewMeta : public ViewMeta {
  27. FUNCTIONALIZATION_VIEWMETA_NAME(_unsafe_view_ViewMeta)
  28. FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
  29. bool /* has_symbolic_inputs */,
  30. const std::vector<c10::SymInt>&);
  31. _unsafe_view_ViewMeta(const SerializableTuple& tpl)
  32. : _unsafe_view_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
  33. _unsafe_view_ViewMeta(
  34. bool has_symbolic_inputs,
  35. const std::vector<c10::SymInt>& size)
  36. : ViewMeta(has_symbolic_inputs), size(size) {}
  37. Tensor forward(const Tensor& base) override;
  38. Tensor reverse(const Tensor& base, const Tensor& mutated_view) override;
  39. SerializableTuple to_serializable_tuple() {
  40. return std::make_tuple(has_symbolic_inputs, size);
  41. }
  42. std::vector<c10::SymInt> size;
  43. };
  44. } // namespace at::functionalization
  45. #else
  46. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  47. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)