CachedTensorUtils.h 1.2 KB

1234567891011121314151617181920212223242526272829
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/ATen.h>
  4. namespace at::caching {
      // Declarations for the cached-tensor bookkeeping described below. Only the
      // declarations are visible in this header; the definitions (and the
      // "enabled" flag / hash lookup mentioned in the comment) live elsewhere.
  5. // Some systems (just cudagraphs currently) will persist a static tensor output
  6. // whose TensorImpl does not change across iterations. For these tensors caching
  7. // dtype conversions is invalid. Additionally, there will be an extra reference
  8. // count to these cached tensors that would prevent buffer inplacing and other
  9. // checks on tensor uniqueness. If we are not using these systems the enabled
  10. // flag will be false and we will avoid the hash lookup.
      //
      // Query for a single tensor `t`; returns a bool.
      // NOTE(review): presumably true only when `t` has been registered via
      // add_cached_tensor and the enabled flag is on — confirm against the
      // definition.
  11. TORCH_API bool is_cached_tensor(const at::Tensor& t);
      // Takes a tensor `t` and returns nothing.
      // NOTE(review): presumably registers `t` in the cached-tensor set that
      // is_cached_tensor consults — confirm against the definition.
  12. TORCH_API void add_cached_tensor(const at::Tensor& t);
      // Takes a tensor `t` and returns nothing.
      // NOTE(review): presumably the inverse of add_cached_tensor; behavior for a
      // tensor that was never added is not visible here — confirm.
  13. TORCH_API void remove_cached_tensor(const at::Tensor& t);
      // Takes the desired state of the "enabled" flag referred to in the comment
      // above. NOTE(review): whether toggling it clears previously added tensors
      // is not visible here — confirm against the definition.
  14. TORCH_API void set_cached_tensors_enabled(bool enable);
  15. // For gradient buffer stealing we will adjust the use count of tensors
  16. // which are persisted by cudagraphs, just as we need to adjust reference
  17. // count of tensors with hooks.
      // Returns the adjusted use count of `t` as a size_t.
      // NOTE(review): presumably this discounts the extra reference held on
      // cached tensors (see the comment at the top of this namespace); the exact
      // adjustment is not visible here — confirm against the definition.
  18. TORCH_API size_t adjusted_use_count(const at::Tensor& t);
  19. } // namespace at::caching
  20. #else
  21. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  22. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)