XPUHooksInterface.h 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/Device.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/Registry.h>
  6. #include <ATen/detail/AcceleratorHooksInterface.h>
  7. C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
  8. namespace at {
  // Forward declaration of at::xpu::LevelZero. Its definition is not in this
  // header, so here it can only be named through references/pointers.
  namespace xpu { struct LevelZero; } // namespace xpu
  13. struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{
  14. ~XPUHooksInterface() override = default;
  15. void init() const override {
  16. TORCH_CHECK(false, "Cannot initialize XPU without ATen_xpu library.");
  17. }
  18. virtual bool hasXPU() const {
  19. return false;
  20. }
  21. virtual std::string showConfig() const {
  22. TORCH_CHECK(
  23. false,
  24. "Cannot query detailed XPU version without ATen_xpu library.");
  25. }
  26. virtual int32_t getGlobalIdxFromDevice(const Device& device) const {
  27. TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
  28. }
  29. const Generator& getDefaultGenerator(
  30. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  31. TORCH_CHECK(
  32. false, "Cannot get default XPU generator without ATen_xpu library.");
  33. }
  34. Generator getNewGenerator(
  35. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  36. TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library.");
  37. }
  38. virtual DeviceIndex getNumGPUs() const {
  39. return 0;
  40. }
  41. virtual DeviceIndex current_device() const {
  42. TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
  43. }
  44. Device getDeviceFromPtr(void* /*data*/) const override {
  45. TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
  46. }
  47. virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
  48. TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library.");
  49. }
  50. Allocator* getPinnedMemoryAllocator() const override {
  51. TORCH_CHECK(false, "Cannot get XPU pinned memory allocator without ATen_xpu library.");
  52. }
  53. bool isPinnedPtr(const void* data) const override {
  54. return false;
  55. }
  56. bool hasPrimaryContext(DeviceIndex device_index) const override {
  57. TORCH_CHECK(false, "Cannot query primary context without ATen_xpu library.");
  58. }
  59. virtual const at::xpu::LevelZero& level_zero() const {
  60. TORCH_CHECK(false, "Level zero requires XPU.");
  61. }
  62. };
  63. struct TORCH_API XPUHooksArgs {};
  64. TORCH_DECLARE_REGISTRY(XPUHooksRegistry, XPUHooksInterface, XPUHooksArgs);
  65. #define REGISTER_XPU_HOOKS(clsname) \
  66. C10_REGISTER_CLASS(XPUHooksRegistry, clsname, clsname)
  67. namespace detail {
  68. TORCH_API const XPUHooksInterface& getXPUHooks();
  69. } // namespace detail
  70. } // namespace at
  71. C10_DIAGNOSTIC_POP()
  72. #else
  73. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  74. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)