MPSHooksInterface.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright © 2022 Apple Inc.
  3. #pragma once
  4. #include <ATen/detail/AcceleratorHooksInterface.h>
  5. #include <c10/core/Allocator.h>
  6. #include <c10/util/Exception.h>
  7. #include <c10/util/Registry.h>
  8. #include <cstddef>
  9. C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
  10. namespace at {
  11. struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
  12. // this fails the implementation if MPSHooks functions are called, but
  13. // MPS backend is not present.
  14. #define FAIL_MPSHOOKS_FUNC(func) \
  15. TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");
  16. ~MPSHooksInterface() override = default;
  17. // Initialize the MPS library state
  18. void init() const override {
  19. FAIL_MPSHOOKS_FUNC(__func__);
  20. }
  21. virtual bool hasMPS() const {
  22. return false;
  23. }
  24. virtual bool isOnMacOSorNewer(unsigned major = 13, unsigned minor = 0) const {
  25. FAIL_MPSHOOKS_FUNC(__func__);
  26. }
  27. const Generator& getDefaultGenerator(
  28. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  29. FAIL_MPSHOOKS_FUNC(__func__);
  30. }
  31. Generator getNewGenerator(
  32. [[maybe_unused]] DeviceIndex device_index) const override {
  33. FAIL_MPSHOOKS_FUNC(__func__);
  34. }
  35. virtual Allocator* getMPSDeviceAllocator() const {
  36. FAIL_MPSHOOKS_FUNC(__func__);
  37. }
  38. virtual void deviceSynchronize() const {
  39. FAIL_MPSHOOKS_FUNC(__func__);
  40. }
  41. virtual void commitStream() const {
  42. FAIL_MPSHOOKS_FUNC(__func__);
  43. }
  44. virtual void* getCommandBuffer() const {
  45. FAIL_MPSHOOKS_FUNC(__func__);
  46. }
  47. virtual void* getDispatchQueue() const {
  48. FAIL_MPSHOOKS_FUNC(__func__);
  49. }
  50. virtual void emptyCache() const {
  51. FAIL_MPSHOOKS_FUNC(__func__);
  52. }
  53. virtual size_t getCurrentAllocatedMemory() const {
  54. FAIL_MPSHOOKS_FUNC(__func__);
  55. }
  56. virtual size_t getDriverAllocatedMemory() const {
  57. FAIL_MPSHOOKS_FUNC(__func__);
  58. }
  59. virtual size_t getRecommendedMaxMemory() const {
  60. FAIL_MPSHOOKS_FUNC(__func__);
  61. }
  62. virtual void setMemoryFraction(double /*ratio*/) const {
  63. FAIL_MPSHOOKS_FUNC(__func__);
  64. }
  65. virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
  66. FAIL_MPSHOOKS_FUNC(__func__);
  67. }
  68. virtual void profilerStopTrace() const {
  69. FAIL_MPSHOOKS_FUNC(__func__);
  70. }
  71. virtual uint32_t acquireEvent(bool enable_timing) const {
  72. FAIL_MPSHOOKS_FUNC(__func__);
  73. }
  74. Device getDeviceFromPtr(void* data) const override {
  75. TORCH_CHECK(false, "Cannot get device of pointer on MPS without ATen_mps library. ");
  76. }
  77. virtual void releaseEvent(uint32_t event_id) const {
  78. FAIL_MPSHOOKS_FUNC(__func__);
  79. }
  80. virtual void recordEvent(uint32_t event_id) const {
  81. FAIL_MPSHOOKS_FUNC(__func__);
  82. }
  83. virtual void waitForEvent(uint32_t event_id) const {
  84. FAIL_MPSHOOKS_FUNC(__func__);
  85. }
  86. virtual void synchronizeEvent(uint32_t event_id) const {
  87. FAIL_MPSHOOKS_FUNC(__func__);
  88. }
  89. virtual bool queryEvent(uint32_t event_id) const {
  90. FAIL_MPSHOOKS_FUNC(__func__);
  91. }
  92. virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
  93. FAIL_MPSHOOKS_FUNC(__func__);
  94. }
  95. bool hasPrimaryContext(DeviceIndex device_index) const override {
  96. FAIL_MPSHOOKS_FUNC(__func__);
  97. }
  98. bool isPinnedPtr(const void* data) const override {
  99. return false;
  100. }
  101. Allocator* getPinnedMemoryAllocator() const override {
  102. FAIL_MPSHOOKS_FUNC(__func__);
  103. }
  104. #undef FAIL_MPSHOOKS_FUNC
  105. };
  106. struct TORCH_API MPSHooksArgs {};
  107. TORCH_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs);
  108. #define REGISTER_MPS_HOOKS(clsname) \
  109. C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname)
  110. namespace detail {
  111. TORCH_API const MPSHooksInterface& getMPSHooks();
  112. } // namespace detail
  113. } // namespace at
  114. C10_DIAGNOSTIC_POP()
  115. #else
  116. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  117. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)