CUDAGraphsC10Utils.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/cuda/CUDAStream.h>
  4. #include <iostream>
  5. #include <utility>
  6. // CUDA Graphs utils used by c10 and aten.
  7. // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only.
  8. namespace c10::cuda {
  9. // RAII guard for "cudaStreamCaptureMode", a thread-local value
  10. // that controls the error-checking strictness of a capture.
  11. struct C10_CUDA_API CUDAStreamCaptureModeGuard {
  12. CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired)
  13. : strictness_(desired) {
  14. C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_));
  15. }
  16. CUDAStreamCaptureModeGuard(const CUDAStreamCaptureModeGuard&) = delete;
  17. CUDAStreamCaptureModeGuard(CUDAStreamCaptureModeGuard&&) = delete;
  18. CUDAStreamCaptureModeGuard& operator=(const CUDAStreamCaptureModeGuard&) =
  19. delete;
  20. CUDAStreamCaptureModeGuard& operator=(CUDAStreamCaptureModeGuard&&) = delete;
  21. ~CUDAStreamCaptureModeGuard() {
  22. C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_));
  23. }
  24. private:
  25. cudaStreamCaptureMode strictness_;
  26. };
  27. // Protects against enum cudaStreamCaptureStatus implementation changes.
  28. // Some compilers seem not to like static_assert without the messages.
  29. static_assert(
  30. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0,
  31. "unexpected int(cudaStreamCaptureStatusNone) value");
  32. static_assert(
  33. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1,
  34. "unexpected int(cudaStreamCaptureStatusActive) value");
  35. static_assert(
  36. int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2,
  37. "unexpected int(cudaStreamCaptureStatusInvalidated) value");
  38. enum class CaptureStatus : int {
  39. None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone),
  40. Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive),
  41. Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated)
  42. };
  43. inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) {
  44. switch (status) {
  45. case CaptureStatus::None:
  46. os << "cudaStreamCaptureStatusNone";
  47. break;
  48. case CaptureStatus::Active:
  49. os << "cudaStreamCaptureStatusActive";
  50. break;
  51. case CaptureStatus::Invalidated:
  52. os << "cudaStreamCaptureStatusInvalidated";
  53. break;
  54. default:
  55. TORCH_INTERNAL_ASSERT(
  56. false, "Unknown CUDA graph CaptureStatus", int(status));
  57. }
  58. return os;
  59. }
  60. // Use this version where you're sure a CUDA context exists already.
  61. inline CaptureStatus currentStreamCaptureStatusMayInitCtx() {
  62. cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone};
  63. C10_CUDA_CHECK(
  64. cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing));
  65. return CaptureStatus(is_capturing);
  66. }
  67. } // namespace c10::cuda
  68. #else
  69. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  70. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)