DTensorState.h 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/macros/Macros.h>
  4. namespace at {
  5. TORCH_API bool get_dtensor_allow_implicit_replication();
  6. TORCH_API void set_dtensor_allow_implicit_replication(bool enabled);
  7. struct DTensorAllowImplicitReplication {
  8. DTensorAllowImplicitReplication()
  9. : prev_dtensor_allow_implicit_replication_(
  10. get_dtensor_allow_implicit_replication()) {
  11. set_dtensor_allow_implicit_replication(true);
  12. }
  13. DTensorAllowImplicitReplication(const DTensorAllowImplicitReplication&) =
  14. delete;
  15. DTensorAllowImplicitReplication& operator=(
  16. const DTensorAllowImplicitReplication&) = delete;
  17. DTensorAllowImplicitReplication(DTensorAllowImplicitReplication&&) = delete;
  18. DTensorAllowImplicitReplication& operator=(
  19. DTensorAllowImplicitReplication&&) = delete;
  20. ~DTensorAllowImplicitReplication() {
  21. set_dtensor_allow_implicit_replication(
  22. prev_dtensor_allow_implicit_replication_);
  23. }
  24. private:
  25. bool prev_dtensor_allow_implicit_replication_;
  26. };
  27. } // namespace at
  28. #else
  29. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  30. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)