| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
#include <c10/core/Allocator.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <atomic>
#include <memory>
- namespace at::cuda {
- // Keep BC only
- using c10::CaptureId_t;
- using c10::MempoolId_t;
- // MemPool represents a pool of memory in a caching allocator. Currently,
- // it's just the ID of the pool object maintained in the CUDACachingAllocator.
- //
- // An allocator pointer can be passed to the MemPool to define how the
- // allocations should be done in the pool. For example: using a different
- // system allocator such as ncclMemAlloc.
- struct TORCH_CUDA_CPP_API MemPool {
- MemPool(
- std::shared_ptr<c10::cuda::CUDACachingAllocator::CUDAAllocator> allocator =
- nullptr,
- bool is_user_created = true,
- bool use_on_oom = false,
- bool no_split = false);
- MemPool(const MemPool&) = delete;
- MemPool(MemPool&&) = default;
- MemPool& operator=(const MemPool&) = delete;
- MemPool& operator=(MemPool&&) = default;
- ~MemPool();
- MempoolId_t id();
- int use_count();
- c10::DeviceIndex device();
- static MempoolId_t graph_pool_handle(bool is_user_created = true);
- private:
- static std::atomic<CaptureId_t> uid_;
- static std::atomic<CaptureId_t> uuid_;
- bool is_user_created_;
- MempoolId_t id_;
- c10::DeviceIndex device_;
- };
- } // namespace at::cuda
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|