XLAHooksInterface.h 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
#include <ATen/detail/AcceleratorHooksInterface.h>

#include <string>
  7. C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
  8. namespace at {
  9. constexpr const char* XLA_HELP =
  10. "This error has occurred because you are trying "
  11. "to use some XLA functionality, but the XLA library has not been "
  12. "loaded by the dynamic linker. You must load xla libraries by `import torch_xla`";
  13. struct TORCH_API XLAHooksInterface : AcceleratorHooksInterface {
  14. ~XLAHooksInterface() override = default;
  15. void init() const override {
  16. TORCH_CHECK(false, "Cannot initialize XLA without torch_xla library. ", XLA_HELP);
  17. }
  18. virtual bool hasXLA() const {
  19. return false;
  20. }
  21. virtual std::string showConfig() const {
  22. TORCH_CHECK(
  23. false,
  24. "Cannot query detailed XLA version without torch_xla library. ",
  25. XLA_HELP);
  26. }
  27. const Generator& getDefaultGenerator(
  28. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  29. TORCH_CHECK(
  30. false, "Cannot get default XLA generator without torch_xla library. ", XLA_HELP);
  31. }
  32. Generator getNewGenerator(
  33. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  34. TORCH_CHECK(false, "Cannot get XLA generator without torch_xla library. ", XLA_HELP);
  35. }
  36. DeviceIndex getCurrentDevice() const override {
  37. TORCH_CHECK(false, "Cannot get current XLA device without torch_xla library. ", XLA_HELP);
  38. }
  39. Device getDeviceFromPtr(void* /*data*/) const override {
  40. TORCH_CHECK(false, "Cannot get device of pointer on XLA without torch_xla library. ", XLA_HELP);
  41. }
  42. Allocator* getPinnedMemoryAllocator() const override {
  43. TORCH_CHECK(false, "Cannot get XLA pinned memory allocator without torch_xla library. ", XLA_HELP);
  44. }
  45. bool isPinnedPtr(const void* data) const override {
  46. return false;
  47. }
  48. bool hasPrimaryContext(DeviceIndex device_index) const override {
  49. TORCH_CHECK(false, "Cannot query primary context without torch_xla library. ", XLA_HELP);
  50. }
  51. };
  52. struct TORCH_API XLAHooksArgs {};
  53. TORCH_DECLARE_REGISTRY(XLAHooksRegistry, XLAHooksInterface, XLAHooksArgs);
  54. #define REGISTER_XLA_HOOKS(clsname) \
  55. C10_REGISTER_CLASS(XLAHooksRegistry, clsname, clsname)
  56. namespace detail {
  57. TORCH_API const XLAHooksInterface& getXLAHooks();
  58. } // namespace detail
  59. } // namespace at
  60. C10_DIAGNOSTIC_POP()
  61. #else
  62. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  63. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)