MPSStream.h 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright © 2022 Apple Inc.
  3. #pragma once
  4. #include <cstdint>
  5. #include <utility>
  6. #include <ATen/mps/MPSDevice.h>
  7. #include <c10/core/DeviceGuard.h>
  8. #include <c10/core/Stream.h>
  9. #include <c10/util/Exception.h>
  10. #ifdef __OBJC__
  11. #include <Foundation/Foundation.h>
  12. #include <Metal/Metal.h>
  13. #include <MetalPerformanceShaders/MetalPerformanceShaders.h>
  14. #include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
  15. typedef MPSCommandBuffer* MPSCommandBuffer_t;
  16. typedef id<MTLCommandQueue> MTLCommandQueue_t;
  17. typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
  18. typedef id<MTLSharedEvent> MTLSharedEvent_t;
  19. typedef id<MTLDevice> MTLDevice_t;
  20. typedef id<MTLBuffer> MTLBuffer_t;
  21. #else
  22. #include <dispatch/dispatch.h>
  23. typedef void* MPSCommandBuffer_t;
  24. typedef void* MPSGraph;
  25. typedef void* MPSGraphExecutionDescriptor;
  26. typedef void* MPSGraphCompilationDescriptor;
  27. typedef void* MTLCommandQueue_t;
  28. typedef void* MTLComputeCommandEncoder_t;
  29. typedef void* MTLSharedEvent_t;
  30. typedef void* MTLDevice_t;
  31. typedef void* MTLBuffer_t;
  32. typedef void* MTLCommandBufferHandler;
  33. typedef void* NSDictionary;
  34. #define nil NULL
  35. #endif
  36. namespace at::mps {
  37. //-----------------------------------------------------------------
  38. // MPSStream
  39. //-----------------------------------------------------------------
  /// Selects what is done with the command buffer when a stream operation
  /// is asked to synchronize.
  enum class SyncType {
    /// No commit to the command buffer.
    NONE,
    /// Commit and flush the command buffer.
    COMMIT,
    /// Flush and wait for command buffer execution to finish.
    COMMIT_AND_WAIT,
    /// Commit and continue with a new underlying command buffer.
    COMMIT_AND_CONTINUE,
    /// Commit adaptively based on available memory.
    COMMIT_ADAPTIVE,
  };
  47. class TORCH_API MPSStream {
  48. public:
  49. enum Unchecked { UNCHECKED };
  50. /// Construct a MPSStream from a Stream. This construction is checked,
  51. /// and will raise an error if the Stream is not, in fact, a MPS stream.
  52. explicit MPSStream(Stream stream);
  53. ~MPSStream();
  54. MTLCommandQueue_t commandQueue() const {
  55. return _commandQueue;
  56. }
  57. dispatch_queue_t queue() const {
  58. return _serialQueue;
  59. }
  60. MPSCommandBuffer_t commandBuffer();
  61. MTLComputeCommandEncoder_t commandEncoder();
  62. void endKernelCoalescing();
  63. void synchronize(SyncType syncType);
  64. void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
  65. void copy(MTLBuffer_t srcBuffer,
  66. MTLBuffer_t dstBuffer,
  67. size_t length,
  68. size_t srcOffset,
  69. size_t dstOffset,
  70. uint64_t profileId,
  71. SyncType syncType = SyncType::NONE);
  72. void copy_and_sync(MTLBuffer_t srcBuffer,
  73. MTLBuffer_t dstBuffer,
  74. size_t length,
  75. size_t srcOffset,
  76. size_t dstOffset,
  77. bool non_blocking,
  78. uint64_t profileId);
  79. void executeMPSGraph(MPSGraph* mpsGraph,
  80. NSDictionary* feeds,
  81. NSDictionary* results,
  82. SyncType syncType = SyncType::NONE);
  83. void addCompletedHandler(MTLCommandBufferHandler block);
  84. /// Get the MPS device index that this stream is associated with.
  85. c10::DeviceIndex device_index() const {
  86. return _stream.device_index();
  87. }
  88. MTLCommandQueue_t stream() const {
  89. return _commandQueue;
  90. }
  91. MTLDevice_t device() const;
  92. /// Explicit conversion to Stream.
  93. Stream unwrap() const {
  94. return _stream;
  95. }
  96. MTLBuffer_t getErrorBuffer();
  97. void checkLastError();
  98. private:
  99. Stream _stream;
  100. MTLCommandQueue_t _commandQueue = nil;
  101. MPSCommandBuffer_t _commandBuffer = nil;
  102. MPSCommandBuffer_t _prevCommandBuffer = nil;
  103. MTLComputeCommandEncoder_t _commandEncoder = nil;
  104. MPSGraphExecutionDescriptor* _executionDescriptor = nil;
  105. MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
  106. dispatch_queue_t _serialQueue = nullptr;
  107. // CommitAndContinue is enabled by default
  108. bool _enableCommitAndContinue = true;
  109. // Buffer that contains last raised error
  110. MTLBuffer_t _errorBuffer = nil;
  111. // use synchronize() to access any of these commit functions outside MPSStream
  112. void commit();
  113. void commitAndWait();
  114. void commitAndContinue();
  115. void flush();
  116. };
  /**
   * Get the current MPS stream.
   *
   * @return Pointer to the current MPSStream.
   *
   * NOTE(review): ownership and lifetime of the returned pointer are not
   * visible in this header; presumably callers must not delete it. Confirm
   * against the implementation.
   */
  TORCH_API MPSStream* getCurrentMPSStream();
  /**
   * Get the default MPS stream.
   *
   * @return Pointer to the default MPSStream.
   *
   * NOTE(review): ownership and lifetime of the returned pointer are not
   * visible in this header; presumably callers must not delete it. Confirm
   * against the implementation.
   */
  TORCH_API MPSStream* getDefaultMPSStream();
  125. //-----------------------------------------------------------------
  126. // MPSStreamImpl
  127. //-----------------------------------------------------------------
  128. class TORCH_API MPSStreamImpl {
  129. public:
  130. /**
  131. * Gets single instance of the MPSStream.
  132. */
  133. static MPSStream* getInstance();
  134. private:
  135. static MPSStream* _stream;
  136. MPSStreamImpl();
  137. };
  #ifdef __OBJC__
  // Declared only for Objective-C(++) translation units: the second parameter
  // is a block (`void (^)()`), a type that plain C++ callers of this header
  // cannot name.
  // NOTE(review): by its name this presumably runs `block` synchronously on
  // `queue` and rethrows on the calling thread any exception raised inside the
  // block; the definition is not in this header, so confirm before relying on it.
  void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
  #endif
  141. } // namespace at::mps
  142. #else
  143. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  144. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)