// CUDAGreenContext.h

  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  #include <ATen/cuda/CUDAEvent.h>
  #include <cuda.h>

  #include <array>
  #include <atomic>
  #include <cstdint>
  #include <memory>
  #include <optional>
  // Forward declaration of the CUDA driver's green-context handle as a pointer
  // to an opaque struct that is left incomplete here. C++ permits re-declaring
  // an alias to the same type, so this does not conflict with a <cuda.h> that
  // already declares CUgreenCtx as `struct CUgreenCtx_st*`.
  using CUgreenCtx = struct CUgreenCtx_st*;
  7. namespace at::cuda {
  8. namespace {
  9. constexpr int kStreamPerGreenContextPool = 32;
  10. }
  11. class TORCH_CUDA_CPP_API GreenContext {
  12. public:
  13. // Green context creation
  14. static std::unique_ptr<GreenContext> create(
  15. uint32_t num_sms,
  16. std::optional<uint32_t> device_id);
  17. ~GreenContext() noexcept;
  18. // Delete copy constructor and assignment
  19. GreenContext(const GreenContext&) = delete;
  20. GreenContext& operator=(const GreenContext&) = delete;
  21. // Make this context current
  22. void setContext();
  23. void popContext();
  24. CUDAStream Stream();
  25. private:
  26. GreenContext(uint32_t device_id, uint32_t num_sms);
  27. // Implement move operations
  28. GreenContext(GreenContext&& other) noexcept;
  29. GreenContext& operator=(GreenContext&& other) noexcept;
  30. int32_t device_id_ = -1;
  31. CUgreenCtx green_ctx_ = nullptr;
  32. CUcontext context_ = nullptr;
  33. cudaStream_t parent_stream_ = nullptr;
  34. std::array<CUstream, kStreamPerGreenContextPool> green_ctx_streams_;
  35. std::atomic<int32_t> curr_stream_idx_ = -1;
  36. };
  37. } // namespace at::cuda
  38. #else
  39. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  40. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)