CUDAGraphsUtils.cuh 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/cuda/CUDAGeneratorImpl.h>
  4. #include <ATen/cuda/CUDAEvent.h>
  5. #include <ATen/cuda/PhiloxUtils.cuh>
  6. #include <ATen/cuda/detail/CUDAHooks.h>
  7. #include <ATen/detail/CUDAHooksInterface.h>
  8. #include <c10/core/StreamGuard.h>
  9. #include <c10/cuda/CUDAGraphsC10Utils.h>
  10. #include <c10/cuda/CUDAGuard.h>
  11. // c10/cuda/CUDAGraphsC10Utils.h has utils used by both c10 and aten.
  12. // This file adds utils used by aten only.
  13. namespace at::cuda {
  14. using CaptureId_t = c10::cuda::CaptureId_t;
  15. using CaptureStatus = c10::cuda::CaptureStatus;
  16. // Use this version where you don't want to create a CUDA context if none exists.
  17. inline CaptureStatus currentStreamCaptureStatus() {
  18. // don't create a context if we don't have to
  19. if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
  20. return c10::cuda::currentStreamCaptureStatusMayInitCtx();
  21. } else {
  22. return CaptureStatus::None;
  23. }
  24. }
  25. inline void assertNotCapturing(const std::string& attempt) {
  26. auto status = currentStreamCaptureStatus();
  27. TORCH_CHECK(status == CaptureStatus::None,
  28. attempt,
  29. " during CUDA graph capture. If you need this call to be captured, "
  30. "please file an issue. "
  31. "Current cudaStreamCaptureStatus: ",
  32. status);
  33. }
  34. inline void errorIfCapturingCudnnBenchmark(const std::string& version_specific) {
  35. auto status = currentStreamCaptureStatus();
  36. TORCH_CHECK(status == CaptureStatus::None,
  37. "Current cudaStreamCaptureStatus: ",
  38. status,
  39. "\nCapturing ",
  40. version_specific,
  41. "is prohibited. Possible causes of this error:\n"
  42. "1. No warmup iterations occurred before capture.\n"
  43. "2. The convolutions you're trying to capture use dynamic shapes, "
  44. "in which case capturing them is generally prohibited.");
  45. }
  46. } // namespace at::cuda
  47. #else
  48. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  49. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)