CUDAGraph.h 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/Tensor.h>
  4. #include <c10/core/Device.h>
  5. #include <c10/cuda/CUDACachingAllocator.h>
  6. #include <c10/cuda/CUDAGraphsC10Utils.h>
  7. #include <c10/cuda/CUDAStream.h>
  8. #include <c10/util/flat_hash_map.h>
  9. namespace at {
  10. struct Generator;
  11. struct CUDAGeneratorImpl;
  12. struct CUDAGeneratorState;
  13. namespace cuda {
  14. // Standalone way to get a unique mempool id usable as a pool=... argument
  15. // to CUDAGraph::capture_begin
  16. TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle();
  17. struct TORCH_CUDA_CPP_API CUDAGraph {
  18. CUDAGraph(bool keep_graph=false);
  19. ~CUDAGraph();
  20. // See Note [Explicit Registration of Generators to the CUDA Graph]
  21. void register_generator_state(c10::intrusive_ptr<at::CUDAGeneratorState> state);
  22. void register_generator_state(const at::Generator& generator);
  23. void capture_begin(
  24. MempoolId_t pool = {0, 0},
  25. cudaStreamCaptureMode capture_mode = cudaStreamCaptureModeGlobal);
  26. void capture_end();
  27. void instantiate();
  28. void replay();
  29. void reset();
  30. MempoolId_t pool();
  31. void enable_debug_mode();
  32. void debug_dump(const std::string& debug_path);
  33. cudaGraph_t raw_cuda_graph();
  34. cudaGraphExec_t raw_cuda_graph_exec();
  35. protected:
  36. cudaGraph_t graph_ = nullptr;
  37. cudaGraphExec_t graph_exec_ = nullptr;
  38. // internal states so reset() can do its best cleaning up
  39. // Set to true in capture_end if cudaStreamEndCapture succeeded
  40. // Set back to false after instantiate() unless keep_graph=True or
  41. // enable_debug_mode() was called on any CUDAGraph instance.
  42. bool has_graph_ = false;
  43. // Set to true in capture_end if cudaStreamEndCapture succeeded
  44. bool capture_ended_ = false;
  45. // Set to true in capture_end if cudaGraphInstantiate succeeded
  46. bool has_graph_exec_ = false;
  47. // the ID assigned by cuda during graph capture,
  48. // used to identify when a stream is participating in capture
  49. CaptureId_t capture_id_ = 0;
  50. // uuid used to request a particular private mempool from CUDACachingAllocator.
  51. // By default, this will be set to {id_, 0}.
  52. //
  53. // If capture_begin is called with "pool=other_graph.pool()", this graph's mempool_id_
  54. // will be set to the other graph's mempool_id_, and therefore share a mempool with the
  55. // other graph.
  56. //
  57. // If capture_begin is called with "pool=handle" where "handle" came from graph_pool_handle(),
  58. // it will share a mempool with any other captures that used "pool=handle".
  59. //
  60. // Sharing a mempool across graphs saves memory, and it's safe if you
  61. // know you'll replay those graphs in the same order you captured them.
  62. MempoolId_t mempool_id_;
  63. // Stream on which capture began
  64. at::cuda::CUDAStream capture_stream_;
  65. // multiple generator states and their wholegraph_increments in this graph
  66. // that are managed by the CUDA Graph
  67. ska::flat_hash_map<c10::intrusive_ptr<at::CUDAGeneratorState>, uint64_t>
  68. captured_generator_states_;
  69. // Device where capture occurred. Right now, for simplicity, we require all ops
  70. // in a capture to run on the same device, but this is a limitation of CUDAGraph,
  71. // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
  72. // captures if needed.
  73. // init capture_dev_ as UNDEFINED_DEVICE to check that it stores the real device id in the destructor
  74. static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1;
  75. c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE};
  76. bool keep_graph_;
  77. };
  78. } // namespace cuda
  79. } // namespace at
  80. #else
  81. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  82. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)