LegacyVmapMode.h 1.2 KB

12345678910111213141516171819202122232425262728293031
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/impl/LocalDispatchKeySet.h>
  4. namespace at::impl {
  5. // VmapMode contains a thread local count of how many nested vmaps
  6. // we are currently inside. That number is known as the `vmap level`.
  7. // VmapMode is used in the implementation of the Python `torch.vmap` API.
  8. //
  9. // NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet.
  10. struct TORCH_API VmapMode {
  11. // Returns the vmap level, aka the count of how many nested vmaps we're in.
  12. static int64_t current_vmap_level();
  13. // Increment the count of nested vmaps. If this causes the vmap level to be
  14. // greater than 0, then it enables DispatchKey::VmapMode on all tensors.
  15. static int64_t increment_nesting();
  16. // Decrements the count of nested vmaps. If this causes the vmap level to be
  17. // equal to 0, then it disables DispatchKey::VmapMode on all tensors.
  18. static int64_t decrement_nesting();
  19. };
  20. } // namespace at::impl
  21. #else
  22. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  23. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)