MemPool.h 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/Allocator.h>
  4. #include <c10/cuda/CUDACachingAllocator.h>
  5. #include <memory>
  6. namespace at::cuda {
  7. // Keep BC only
  8. using c10::CaptureId_t;
  9. using c10::MempoolId_t;
  10. // MemPool represents a pool of memory in a caching allocator. Currently,
  11. // it's just the ID of the pool object maintained in the CUDACachingAllocator.
  12. //
  13. // An allocator pointer can be passed to the MemPool to define how the
  14. // allocations should be done in the pool. For example: using a different
  15. // system allocator such as ncclMemAlloc.
  16. struct TORCH_CUDA_CPP_API MemPool {
  17. MemPool(
  18. std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator =
  19. nullptr,
  20. bool is_user_created = true,
  21. bool use_on_oom = false,
  22. bool no_split = false);
  23. MemPool(const MemPool&) = delete;
  24. MemPool(MemPool&&) = default;
  25. MemPool& operator=(const MemPool&) = delete;
  26. MemPool& operator=(MemPool&&) = default;
  27. ~MemPool();
  28. MempoolId_t id();
  29. int use_count();
  30. c10::DeviceIndex device();
  31. static MempoolId_t graph_pool_handle(bool is_user_created = true);
  32. private:
  33. static std::atomic<CaptureId_t> uid_;
  34. static std::atomic<CaptureId_t> uuid_;
  35. bool is_user_created_;
  36. MempoolId_t id_;
  37. c10::DeviceIndex device_;
  38. };
  39. } // namespace at::cuda
  40. #else
  41. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  42. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)