MPSGuardImpl.h 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright © 2022 Apple Inc.
  3. #pragma once
  4. #include <ATen/Context.h>
  5. #include <ATen/mps/MPSEvent.h>
  6. #include <ATen/mps/MPSStream.h>
  7. #include <c10/core/impl/DeviceGuardImplInterface.h>
  8. #include <c10/macros/Macros.h>
  9. #include <c10/util/Exception.h>
  10. #ifdef __OBJC__
  11. #include <Foundation/Foundation.h>
  12. #include <Metal/Metal.h>
  13. #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
  14. #endif
  15. #include <ATen/Tensor.h>
  16. #include <c10/core/MemoryFormat.h>
  17. #include <c10/core/Storage.h>
  18. #include <c10/core/TensorImpl.h>
  19. #include <c10/core/UndefinedTensorImpl.h>
  20. #include <c10/util/intrusive_ptr.h>
  21. #include <sys/_types/_size_t.h>
  22. #include <memory>
  23. namespace at::mps {
  24. typedef MPSEvent* mpsEvent_t;
  25. // TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
  26. // https://github.com/pytorch/pytorch/issues/77170
  27. struct TORCH_API MPSGuardImpl final
  28. : public c10::impl::DeviceGuardImplInterface {
  29. static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
  30. // constructor
  31. MPSGuardImpl() {}
  32. explicit MPSGuardImpl(c10::DeviceType t) {
  33. TORCH_CHECK(
  34. t == DeviceType::MPS,
  35. "MPSGuardImpl initialized with non-MPS DeviceType: ",
  36. t);
  37. }
  38. // returns the type
  39. c10::DeviceType type() const override {
  40. return c10::DeviceType::MPS;
  41. }
  42. Device exchangeDevice(Device d) const override {
  43. return Device(c10::DeviceType::MPS, 0);
  44. }
  45. Device getDevice() const override {
  46. return Device(c10::DeviceType::MPS, 0);
  47. }
  48. std::optional<Device> uncheckedGetDevice() const noexcept {
  49. return Device(c10::DeviceType::MPS, 0);
  50. }
  51. void setDevice(Device d) const override {
  52. TORCH_CHECK(d.is_mps(), "Expected a MPS device, but got ", d);
  53. }
  54. void uncheckedSetDevice(Device d) const noexcept override {
  55. // TODO: Currently setting only device 0
  56. }
  57. Stream getStream(Device d) const override {
  58. return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  59. }
  60. Stream getNewStream(Device, int priority = 0) const override {
  61. (void)priority;
  62. return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  63. }
  64. Stream getDefaultStream(Device d) const override {
  65. return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  66. }
  67. // NB: These do NOT set the current device
  68. Stream exchangeStream(Stream s) const override {
  69. return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0));
  70. }
  71. DeviceIndex deviceCount() const noexcept override {
  72. if (at::hasMPS()) {
  73. // TODO: extend it for multi-device case
  74. return 1;
  75. } else {
  76. return 0;
  77. }
  78. }
  79. // Event-related functions
  80. void createEvent(mpsEvent_t* event, const EventFlag flag) const;
  81. void destroyEvent(void* event, const DeviceIndex device_index)
  82. const noexcept override;
  83. void record(
  84. void** event,
  85. const Stream& stream,
  86. const DeviceIndex device_index,
  87. const EventFlag flag) const override;
  88. void block(void* event, const Stream& stream) const override;
  89. bool queryEvent(void* event) const override;
  90. void synchronizeEvent(void* event) const override;
  91. double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
  92. const override;
  93. void synchronizeDevice(const DeviceIndex device_index) const override;
  94. };
  95. /// A variant of OptionalDeviceGuard that is specialized for MPS.
  96. struct OptionalMPSGuard {
  97. explicit OptionalMPSGuard() : guard_() {}
  98. explicit OptionalMPSGuard(std::optional<Device> device_opt)
  99. : guard_(device_opt) {}
  100. /// Set the current MPS device to the passed device index, if it is not
  101. /// nullopt
  102. explicit OptionalMPSGuard(std::optional<DeviceIndex> device_index_opt)
  103. : guard_(device_index_opt) {}
  104. // Copy is not allowed
  105. OptionalMPSGuard(const OptionalMPSGuard&) = delete;
  106. OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
  107. OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
  108. OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;
  109. /// Sets the MPS device to the given device, initializing the guard if it
  110. /// is not already initialized. Errors if the given device is not a MPS
  111. /// device.
  112. void set_device(Device device) {
  113. guard_.set_device(device);
  114. }
  115. /// Sets the MPS device to the given device, initializing the guard if it is
  116. /// not already initialized. Errors if the given device is not a MPS device.
  117. void reset_device(Device device) {
  118. guard_.reset_device(device);
  119. }
  120. /// Sets the MPS device to the given device index, initializing the guard if
  121. /// it is not already initialized.
  122. void set_index(DeviceIndex device_index) {
  123. guard_.set_index(device_index);
  124. }
  125. /// Returns the device that was set immediately prior to initialization of the
  126. /// guard, or nullopt if the guard is uninitialized.
  127. std::optional<Device> original_device() const {
  128. return guard_.original_device();
  129. }
  130. /// Returns the most recent device that was set using this device guard,
  131. /// either from construction, or via set_device, if the guard is initialized,
  132. /// or nullopt if the guard is uninitialized.
  133. std::optional<Device> current_device() const {
  134. return guard_.current_device();
  135. }
  136. /// Restore the original MPS device, resetting this guard to uninitialized
  137. /// state.
  138. void reset() {
  139. guard_.reset();
  140. }
  141. private:
  142. c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
  143. };
// Invokes C10_REGISTER_GUARD_IMPL for the MPS device type with MPSGuardImpl;
// presumably this makes MPSGuardImpl the guard implementation that c10 looks
// up for DeviceType::MPS.
// NOTE(review): this expands at namespace scope inside a header, so it is
// evaluated in every translation unit that includes this file — confirm that
// repeated registration is harmless.
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl)
  145. } // namespace at::mps
  146. #else
  147. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  148. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)