MTIAHooksInterface.h 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/core/CachingDeviceAllocator.h>
  4. #include <c10/core/Device.h>
  5. #include <c10/util/Exception.h>
  6. #include <c10/core/Stream.h>
  7. #include <c10/util/Registry.h>
  8. #include <c10/core/Allocator.h>
  9. #include <ATen/detail/AcceleratorHooksInterface.h>
  10. #include <c10/util/python_stub.h>
  11. #include <string>
  // Forward declaration of at::Context; only the incomplete type is declared.
  // NOTE(review): at::Context is not referenced anywhere else in this header;
  // possibly kept for includers that rely on it — confirm before removing.
  namespace at {
  class Context;
  } // namespace at
  15. namespace at {
  // User-facing hint describing why MTIA functionality is unavailable: the
  // MTIA extension for PyTorch is not included in this build/process.
  // `inline constexpr` (C++17) gives this header-defined constant a single
  // definition shared by all translation units instead of one copy per TU.
  // The trailing space after "PyTorch;" matters: adjacent string literals are
  // concatenated verbatim, so without it the message read "PyTorch;this".
  inline constexpr const char* MTIA_HELP =
      "The MTIA backend requires MTIA extension for PyTorch; "
      "this error has occurred because you are trying "
      "to use some MTIA's functionality without MTIA extension included.";
  20. struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
  21. // this fails the implementation if MTIAHooks functions are called, but
  22. // MTIA backend is not present.
  23. #define FAIL_MTIAHOOKS_FUNC(func) TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
  24. ~MTIAHooksInterface() override = default;
  25. void init() const override {
  26. // Avoid logging here, since MTIA needs init devices first then it will know
  27. // how many devices are available. Make it as no-op if mtia extension is not
  28. // dynamically loaded.
  29. return;
  30. }
  31. virtual bool hasMTIA() const {
  32. return false;
  33. }
  34. DeviceIndex deviceCount() const override {
  35. return 0;
  36. }
  37. virtual void deviceSynchronize(c10::DeviceIndex /*device_index*/) const {
  38. FAIL_MTIAHOOKS_FUNC(__func__);
  39. }
  40. virtual std::string showConfig() const {
  41. FAIL_MTIAHOOKS_FUNC(__func__);
  42. }
  43. bool hasPrimaryContext(DeviceIndex /*device_index*/) const override {
  44. return false;
  45. }
  46. void setCurrentDevice(DeviceIndex /*device*/) const override {
  47. FAIL_MTIAHOOKS_FUNC(__func__);
  48. }
  49. DeviceIndex getCurrentDevice() const override {
  50. FAIL_MTIAHOOKS_FUNC(__func__);
  51. return -1;
  52. }
  53. DeviceIndex exchangeDevice(DeviceIndex /*device*/) const override {
  54. FAIL_MTIAHOOKS_FUNC(__func__);
  55. return -1;
  56. }
  57. DeviceIndex maybeExchangeDevice(DeviceIndex /*device*/) const override {
  58. FAIL_MTIAHOOKS_FUNC(__func__);
  59. return -1;
  60. }
  61. virtual c10::Stream getCurrentStream(DeviceIndex /*device*/) const {
  62. FAIL_MTIAHOOKS_FUNC(__func__);
  63. return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  64. }
  65. virtual int64_t getCurrentRawStream(DeviceIndex /*device*/) const {
  66. FAIL_MTIAHOOKS_FUNC(__func__);
  67. return -1;
  68. }
  69. virtual c10::Stream getDefaultStream(DeviceIndex /*device*/) const {
  70. FAIL_MTIAHOOKS_FUNC(__func__);
  71. return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  72. }
  73. virtual void setCurrentStream(const c10::Stream& /*stream*/) const {
  74. FAIL_MTIAHOOKS_FUNC(__func__);
  75. }
  76. bool isPinnedPtr(const void* /*data*/) const override {
  77. return false;
  78. }
  79. Allocator* getPinnedMemoryAllocator() const override {
  80. FAIL_MTIAHOOKS_FUNC(__func__);
  81. return nullptr;
  82. }
  83. virtual PyObject* memoryStats(DeviceIndex /*device*/) const {
  84. FAIL_MTIAHOOKS_FUNC(__func__);
  85. return nullptr;
  86. }
  87. virtual PyObject* getDeviceCapability(DeviceIndex /*device*/) const {
  88. FAIL_MTIAHOOKS_FUNC(__func__);
  89. return nullptr;
  90. }
  91. virtual PyObject* getDeviceProperties(DeviceIndex device) const {
  92. FAIL_MTIAHOOKS_FUNC(__func__);
  93. return nullptr;
  94. }
  95. virtual void emptyCache() const {
  96. FAIL_MTIAHOOKS_FUNC(__func__);
  97. }
  98. virtual void recordMemoryHistory(const std::optional<std::string>& /*enabled*/,
  99. const std::string& /*stacks*/,
  100. size_t /*max_entries*/) const {
  101. FAIL_MTIAHOOKS_FUNC(__func__);
  102. }
  103. virtual PyObject* memorySnapshot(const std::optional<std::string>& local_path) const {
  104. FAIL_MTIAHOOKS_FUNC(__func__);
  105. return nullptr;
  106. }
  107. virtual DeviceIndex getDeviceCount() const {
  108. FAIL_MTIAHOOKS_FUNC(__func__);
  109. return 0;
  110. }
  111. virtual void resetPeakMemoryStats(DeviceIndex /*device*/) const {
  112. FAIL_MTIAHOOKS_FUNC(__func__);
  113. }
  114. virtual void attachOutOfMemoryObserver(PyObject* observer) const {
  115. FAIL_MTIAHOOKS_FUNC(__func__);
  116. return;
  117. }
  118. bool isAvailable() const override;
  119. /* MTIAGraph related APIs */
  120. virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
  121. FAIL_MTIAHOOKS_FUNC(__func__);
  122. return -1;
  123. }
  124. virtual void mtiagraphDestroy(int64_t handle) const {
  125. FAIL_MTIAHOOKS_FUNC(__func__);
  126. }
  127. virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
  128. FAIL_MTIAHOOKS_FUNC(__func__);
  129. }
  130. virtual void mtiagraphCaptureEnd(int64_t handle) const {
  131. FAIL_MTIAHOOKS_FUNC(__func__);
  132. }
  133. virtual void mtiagraphInstantiate(int64_t handle) const {
  134. FAIL_MTIAHOOKS_FUNC(__func__);
  135. }
  136. virtual void mtiagraphReplay(int64_t handle) const {
  137. FAIL_MTIAHOOKS_FUNC(__func__);
  138. }
  139. virtual void mtiagraphReset(int64_t handle) const {
  140. FAIL_MTIAHOOKS_FUNC(__func__);
  141. }
  142. virtual MempoolId_t mtiagraphPool(int64_t handle) const {
  143. FAIL_MTIAHOOKS_FUNC(__func__);
  144. }
  145. virtual MempoolId_t graphPoolHandle() const {
  146. FAIL_MTIAHOOKS_FUNC(__func__);
  147. }
  148. const Generator& getDefaultGenerator(DeviceIndex /*device_index*/) const override {
  149. FAIL_MTIAHOOKS_FUNC(__func__);
  150. static Generator dummy_generator;
  151. return dummy_generator;
  152. }
  153. Generator getNewGenerator(DeviceIndex /*device_index*/) const override {
  154. FAIL_MTIAHOOKS_FUNC(__func__);
  155. static Generator dummy_generator;
  156. return dummy_generator;
  157. }
  158. };
  159. struct TORCH_API MTIAHooksArgs {};
  160. TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
  161. #define REGISTER_MTIA_HOOKS(clsname) C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
  162. namespace detail {
  163. TORCH_API const MTIAHooksInterface& getMTIAHooks();
  164. TORCH_API bool isMTIAHooksBuilt();
  165. } // namespace detail
  166. } // namespace at
  167. #else
  168. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  169. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)