// XPUGraph.h (1.9 KB)

  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/Tensor.h>
  4. #include <c10/core/Device.h>
  5. #include <c10/util/flat_hash_map.h>
  6. #include <c10/xpu/XPUCachingAllocator.h>
  7. #include <c10/xpu/XPUGraphsC10Utils.h>
  8. #include <c10/xpu/XPUStream.h>
  9. namespace at {
  10. struct Generator;
  11. struct XPUGeneratorState;
  12. namespace xpu {
  13. TORCH_XPU_API MempoolId_t graph_pool_handle();
  14. using xpuGraph_t = sycl::ext::oneapi::experimental::command_graph<
  15. sycl::ext::oneapi::experimental::graph_state::modifiable>;
  16. using xpuGraphExec_t = sycl::ext::oneapi::experimental::command_graph<
  17. sycl::ext::oneapi::experimental::graph_state::executable>;
  18. struct TORCH_XPU_API XPUGraph {
  19. XPUGraph(bool keep_graph = false);
  20. ~XPUGraph();
  21. void register_generator_state(
  22. c10::intrusive_ptr<at::XPUGeneratorState> state);
  23. void register_generator_state(const at::Generator& generator);
  24. void capture_begin(MempoolId_t pool = {0, 0});
  25. void capture_end();
  26. void instantiate();
  27. void replay();
  28. void reset();
  29. MempoolId_t pool();
  30. void enable_debug_mode();
  31. void debug_dump(const std::string& debug_path);
  32. xpuGraph_t* raw_xpu_graph();
  33. xpuGraphExec_t* raw_xpu_graph_exec();
  34. protected:
  35. std::unique_ptr<xpuGraph_t> graph_;
  36. std::unique_ptr<xpuGraphExec_t> graph_exec_;
  37. bool has_graph_ = false;
  38. bool capture_ended_ = false;
  39. bool has_graph_exec_ = false;
  40. MempoolId_t mempool_id_;
  41. at::xpu::XPUStream capture_stream_;
  42. // GeneratorState and whole graph offset increments mapping
  43. ska::flat_hash_map<c10::intrusive_ptr<at::XPUGeneratorState>, uint64_t>
  44. captured_generator_states_;
  45. static constexpr c10::DeviceIndex UNDEFINED_DEVICE = -1;
  46. c10::DeviceIndex capture_dev_{UNDEFINED_DEVICE};
  47. bool keep_graph_;
  48. };
  49. } // namespace xpu
  50. } // namespace at
  51. #else
  52. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  53. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)