// FuncTorchTLS.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/macros/Macros.h>
  4. #include <memory>
  5. namespace at::functorch {
  6. // NOTE [functorch TLS in pytorch/pytorch]
  7. //
  8. // functorch lives out-of-tree. However, it has some TLS that needs to be
  9. // propagated. The solution for that is we store a pointer to the TLS
  10. // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
  11. // include whatever functorch needs.
  12. //
  13. // We need to store a pointer due to the indirection:
  14. // inside functorch, we will create a subclass of FunctorchTLSBase called
  15. // FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
  16. // FuncTorchTLSBase doesn't have any metadata because it hasn't been defined
  17. // yet.
  18. //
  19. // Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
  20. // functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*.
  21. // We can't directly pass around FunctorchTLSBase (without a pointer) because
  22. // FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
  23. // more elements.
  24. struct TORCH_API FuncTorchTLSBase {
  25. virtual ~FuncTorchTLSBase() = default;
  26. virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
  27. virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
  28. virtual void checkSupportsCppAutogradFunction() const = 0;
  29. virtual void checkSupportsInplaceRequiresGrad() const = 0;
  30. virtual void checkSupportsRetainGrad() const = 0;
  31. };
  32. // returns deepcopy of the functorch tls
  33. TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
  34. // sets the functorch tls. always does a deep copy.
  35. TORCH_API void setFuncTorchTLS(
  36. const std::shared_ptr<const FuncTorchTLSBase>& state);
  37. // get a mutable reference to the functorch tls
  38. TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
  39. } // namespace at::functorch
  40. #else
  41. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  42. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)