MPSProfiler.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright © 2022 Apple Inc.
  3. #pragma once
  4. #include <ATen/Tensor.h>
  5. #include <ATen/mps/MPSAllocatorInterface.h>
  6. #include <ATen/mps/MPSStream.h>
  7. #include <os/log.h>
  8. #include <os/signpost.h>
  9. #include <atomic>
  10. #include <ctime>
  11. #include <sstream>
  12. #include <string>
  13. #include <unordered_map>
  14. #include <utility>
  15. #ifndef __OBJC__
  16. typedef void* MTLCaptureManager;
  17. #endif
  18. namespace at::mps {
  19. namespace Profiler {
  20. struct BaseInfo {
  21. // profiling info types
  22. enum class Type {
  23. GRAPH,
  24. KERNEL,
  25. COPY,
  26. CPU_FALLBACK,
  27. };
  28. BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle)
  29. : type(infoType), profileId(Id), handle(Handle) {}
  30. virtual ~BaseInfo() = default;
  31. // type of profiling info
  32. Type type;
  33. // unique profile ID for execution instances of operations or copies
  34. uint64_t profileId;
  35. // ID generated by os_signpost
  36. // since it's possible to use event and interval-based signposts at the
  37. // same time, we need separate IDs for each.
  38. os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
  39. // accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime -
  40. // GPUStartTime")
  41. std::atomic<double> totalGpuTime{0.0};
  42. // accumulated Scheduling time in ms (obtained from CompletionHandler's
  43. // "KernelEndTime - KernelStartTime")
  44. std::atomic<double> totalSchedulingTime{0.0};
  45. // indicates if the operation or copy execution has completed
  46. std::atomic_bool completed{false};
  47. // handle used to identify the profile info's instance (usually the pointer)
  48. const uintptr_t handle;
  49. virtual const std::string toString(
  50. double gpuTime = 0,
  51. double schedulingTime = 0) const;
  52. // builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
  53. static std::string buildTensorString(
  54. const Tensor& tensor,
  55. bool includeBufferId = false);
  56. static uint64_t getTime() {
  57. return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
  58. }
  59. };
  60. struct OperationInfo : BaseInfo {
  61. OperationInfo(
  62. const void* Handle,
  63. bool IsGraph,
  64. uint64_t Id,
  65. const std::string& StrKey)
  66. : BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)),
  67. strKey(StrKey) {}
  68. uint64_t runCount = 0;
  69. std::string strKey;
  70. const std::string toString(double gpuTime = 0, double schedulingTime = 0)
  71. const override;
  72. // builds a string for a kernel
  73. static std::string buildKernelString(
  74. const std::string& kernelName,
  75. const TensorList& tensors,
  76. bool includeBufferId = false) {
  77. std::stringstream kernelStr;
  78. kernelStr << kernelName;
  79. for (const Tensor& tensor : tensors) {
  80. kernelStr << ':' << BaseInfo::buildTensorString(tensor, includeBufferId);
  81. }
  82. return kernelStr.str();
  83. }
  84. };
  85. struct CpuFbInfo : BaseInfo {
  86. CpuFbInfo(uint64_t Id, const std::string& OpName)
  87. : BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {}
  88. uint64_t runCount = 0;
  89. // the current and total overhead of copies in bytes required to convert the
  90. // Op's input tensors from MPS to CPU and then output from CPU back to MPS
  91. size_t currentCopyOverhead = 0;
  92. size_t totalCopyOverhead = 0;
  93. std::string opName;
  94. std::string strKey;
  95. uint64_t startTime = 0;
  96. const std::string toString(double gpuTime = 0, double schedulingTime = 0)
  97. const override;
  98. void updateCopyOverhead(const TensorList& tensors) {
  99. currentCopyOverhead = 0;
  100. for (const Tensor& tensor : tensors) {
  101. if (tensor.defined()) {
  102. currentCopyOverhead += tensor.nbytes();
  103. }
  104. }
  105. totalCopyOverhead += currentCopyOverhead;
  106. }
  107. };
  108. struct CopyInfo : BaseInfo {
  109. enum class Kind {
  110. MPS_TO_MPS,
  111. MPS_TO_CPU,
  112. CPU_TO_MPS,
  113. };
  114. CopyInfo(
  115. const void* Handle,
  116. size_t Length,
  117. uint64_t Id,
  118. bool IsNonBlocking,
  119. bool UsesBlitter)
  120. : BaseInfo(Type::COPY, Id, uintptr_t(Handle)),
  121. kind(Kind::MPS_TO_MPS),
  122. length(Length),
  123. isNonBlocking(IsNonBlocking),
  124. usesBlitter(UsesBlitter) {}
  125. Kind kind;
  126. size_t length;
  127. bool isNonBlocking;
  128. bool usesBlitter;
  129. std::string srcStrKey;
  130. std::string dstStrKey;
  131. // for copies that don't use blitters, we measure CPU time
  132. uint64_t startTime = 0;
  133. const std::string toString(double gpuTime = 0, double schedulingTime = 0)
  134. const override;
  135. static std::string buildTensorString(
  136. const void* buffer,
  137. const OptionalTensorRef tensor,
  138. bool includeBufferId = false);
  139. static bool isStorageOnMPS(
  140. const void* buffer,
  141. const OptionalTensorRef tensor) {
  142. if (tensor.has_value()) {
  143. return tensor->device().type() == at::kMPS;
  144. }
  145. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer);
  146. // getUnalignedBufferSize() returns -1 if input buffer is not on MPS device
  147. return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
  148. }
  149. static Kind getCopyKind(
  150. const void* srcBuffer,
  151. const void* dstBuffer,
  152. const OptionalTensorRef srcTensor,
  153. const OptionalTensorRef dstTensor) {
  154. const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
  155. const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
  156. TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
  157. if (isSrcOnMPS && !isDstOnMPS) {
  158. return Kind::MPS_TO_CPU;
  159. } else if (!isSrcOnMPS && isDstOnMPS) {
  160. return Kind::CPU_TO_MPS;
  161. }
  162. return Kind::MPS_TO_MPS;
  163. }
  164. };
  165. struct CopyStat : CopyInfo {
  166. explicit CopyStat(std::string CopyKindStr)
  167. : CopyInfo(nullptr, 0, 0, false, false),
  168. kindStr(std::move(CopyKindStr)) {}
  169. // total number of copies
  170. size_t totalCount = 0;
  171. // number of Scalar copies (i.e., less than sizeof(int64))
  172. size_t scalarsCount = 0;
  173. // number of blocking copies (i.e., require syncing to GPU)
  174. size_t blockingCount = 0;
  175. // number of copies that used memcpy(), instead of Metal Blit Encoder
  176. size_t memcpyCount = 0;
  177. // accumulated GPU time in ms for the scalar copies
  178. std::atomic<double> scalarsGpuTime{0.0};
  179. // copy kind in string type
  180. std::string kindStr;
  181. };
  182. class MPSProfiler {
  183. public:
  184. // lower 16 bits used for profiler options
  185. enum ProfileOptions : uint32_t {
  186. OPTIONS_NONE = 0,
  187. // ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK,
  188. // etc.) (used for convenience to not compute bit flags by OR-ing manually)
  189. // trace all signpost types using events
  190. ALL_SIGNPOST_EVENTS = (1 << 0),
  191. // trace all signpost types using intervals
  192. ALL_SIGNPOST_INTERVALS = (1 << 1),
  193. // always wait for command buffer to finish executing after each commit
  194. WAIT_UNTIL_COMPLETED = (1 << 2),
  195. // for interval-based signposts, include the scheduling portion of
  196. // Graph/Kernel/Copy executions as well.
  197. // if flag is disable, only "GPU run time" is included in interval,
  198. // and not schedule time.
  199. INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
  200. // use these if you need to trace signposts types individually (rarely
  201. // required) trace signpost using intervals
  202. USE_INTERVALS = (1 << 4),
  203. // trace signpost by emitting events
  204. USE_EVENTS = (1 << 5),
  205. // used for sanity check (Change this when new option added)
  206. OPTIONS_COUNT = (USE_EVENTS << 1) - 1,
  207. };
  208. // when adding new types, #define the type string in MPSProfiler.mm as well.
  209. // upper 16 bits used for event types
  210. enum SignpostTypes : uint32_t {
  211. SIGNPOST_NONE = 0,
  212. // trace signposts for PyTorch operation executions
  213. RUN_OPERATION = (1 << 16),
  214. // trace signposts for blitter copies
  215. BLIT_COPY = (1 << 17),
  216. // trace signposts for ops that fall back on CPU
  217. CPU_FALLBACK = (1 << 18),
  218. // used for sanity check (Change this when new type added)
  219. SIGNPOST_COUNT = (CPU_FALLBACK << 1) - 1,
  220. };
  221. enum LogOptions : uint32_t {
  222. LOG_NONE = 0,
  223. // Info logging options during execution
  224. // -------------------------------------
  225. // prints operation info (id/key/run_count) during execution
  226. OPERATION_INFO = (1 << 0),
  227. // prints copy info (src/dst tensors/buffers, size, etc.) during execution
  228. COPY_INFO = (1 << 1),
  229. // prints CPU Fallback info (id/runCount/opName/copyOverhead) during
  230. // execution
  231. CPU_FALLBACK_INFO = (1 << 2),
  232. // Profiling Statistics logging options when process terminates
  233. // ------------------------------------------------------------
  234. // prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before
  235. // process terminates this is convenient to not combine following stats bit
  236. // flags manually
  237. ALL_STATS = (1 << 3),
  238. // prints operation stats (GPU times, run count, etc.) before process
  239. // terminates
  240. OPERATION_STATS = (1 << 4),
  241. // prints copies stats (GPU times, copy kinds, sizes, etc.) before process
  242. // terminates
  243. COPY_STATS = (1 << 5),
  244. // prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
  245. // for tensors, etc.) before process terminates
  246. CPU_FALLBACK_STATS = (1 << 6),
  247. // Metadata format options when logging the info
  248. // ---------------------------------------------
  249. // if enabled, includes GPU run time in metadata (i.e.,
  250. // GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324
  251. // ms])
  252. INCLUDE_GPU_TIME = (1 << 7),
  253. // if enabled, includes GPU scheduling time in metadata separately
  254. // (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
  255. // e.g., [GPU=0.324 ms, KRNL=0.036 ms]
  256. INCLUDE_KERNEL_TIME = (1 << 8),
  257. // if enabled, includes the unique buffer ID in metadata for the storage
  258. // of a tensor that was allocated on MPSAllocator. This is useful (along
  259. // with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are
  260. // involved with various operations.
  261. INCLUDE_BUFFER_ID = (1 << 9),
  262. // used for sanity check (Change this when new option added)
  263. LOG_COUNT = (INCLUDE_BUFFER_ID << 1) - 1,
  264. };
  265. explicit MPSProfiler();
  266. ~MPSProfiler();
  267. // the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal
  268. // Kernels the beginProfile*() functions return a profileId which is unique
  269. // per graph/kernel/copy
  270. uint64_t beginProfileKernel(
  271. const void* handle,
  272. const std::string& strKey,
  273. bool isGraph);
  274. uint64_t beginProfileKernel(
  275. const void* handle,
  276. const std::string& kernelName,
  277. const TensorList& tensors);
  278. uint64_t beginProfileCopy(
  279. const void* srcBuffer,
  280. const void* dstBuffer,
  281. const OptionalTensorRef srcTensor,
  282. const OptionalTensorRef dstTensor,
  283. size_t length,
  284. bool isNonBlocking,
  285. bool usesBlitter = true);
  286. uint64_t beginProfileCPUFallback(
  287. const std::string& opName,
  288. const TensorList& tensors);
  289. void beginProfileGPUInterval(const void* handle);
  290. void endProfileCopy(uint64_t profileId, SyncType syncType);
  291. void endProfileKernel(const void* handle, SyncType syncType = SyncType::NONE);
  292. void endProfileCPUFallback(const std::string& opName);
  293. // these are used to hook into Python bindings for torch.mps.profiler module.
  294. // this enables generating OS Signpost traces from MPSProfiler on-demand
  295. // during runtime (instead of environment variables).
  296. // The "mode" could be either "interval", "event", or both "interval,event"
  297. // for interval-based and/or event-based signpost tracing.
  298. void StartTrace(const std::string& mode, bool waitUntilCompleted);
  299. void StopTrace();
  300. // Abstractions for GPU trace capturing
  301. bool isCaptureEnabled() const;
  302. bool isCapturing() const;
  303. void startCapture(const std::string& name, MPSStream* stream = nullptr);
  304. void stopCapture(MPSStream* stream = nullptr);
  305. // convenience functions to indicate whether signpost tracing or
  306. // logging are enabled for the SignpostTypes
  307. bool isOperationProfilingEnabled() const {
  308. return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
  309. (m_log_options &
  310. (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
  311. }
  312. bool isCopyProfilingEnabled() const {
  313. return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
  314. (m_log_options & (LogOptions::COPY_INFO | LogOptions::COPY_STATS));
  315. }
  316. bool isCPUFallbackProfilingEnabled() const {
  317. return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
  318. (m_log_options &
  319. (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
  320. }
  321. bool isSignpostTracingEnabled() const {
  322. return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
  323. }
  324. private:
  325. // indicates what type of signpost types are enabled and traced by MPS
  326. // profiler.
  327. uint32_t m_signpost_types = 0;
  328. uint32_t m_profile_options = 0;
  329. uint32_t m_log_options = 0;
  330. uint64_t m_kernel_counter = 0;
  331. uint64_t m_graph_counter = 0;
  332. uint64_t m_cpu_fb_counter = 0;
  333. uint64_t m_copy_counter = 0;
  334. // technically, it's possible to trace both events and intervals at the same
  335. // time so we use separate os_log categories for them
  336. os_log_t m_os_log_events;
  337. os_log_t m_os_log_intervals;
  338. // stats logging could run either from destructor or signal handler
  339. // so this is used to check if logging has already started.
  340. std::atomic_bool hasLoggedStats{false};
  341. // indicates there are pending completionHandler callbacks that haven't been
  342. // called yet.
  343. std::atomic_bool hasPendingCompletionHandlers{false};
  344. // used to capture sigint signal to log profiling stats
  345. static struct sigaction currentSigint, previousSigint;
  346. // We use the following lists for two reasons:
  347. // 1- for interval-based signposts the "begin" point won't be in same function
  348. // as the "end" point where we need to be able to retrieve signpost's info
  349. // 2- if Operations info need to be logged when process ends using
  350. // LogOptions::OPERATION_INFO.
  351. // the pointer key for this map is either "MPSGraph*" or
  352. // "id<MTLComputePipelineState>" for Metal Kernels this list is retained and
  353. // could be logged along with aggregate profiling numbers when the process
  354. // ends.
  355. std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>>
  356. m_op_info_list{};
  357. // the string key for this map is the op name that we fall back to execute on
  358. // CPU this list is retained and could be logged along with aggregate
  359. // profiling numbers when the process ends.
  360. std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>>
  361. m_cpu_fb_info_list{};
  362. // this list contains the info for copies, and its key is the unique profileId
  363. // which is generated from m_copy_counter
  364. // The copyInfo list is not retained.
  365. std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
  366. // a short list that contains copy stats
  367. std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>>
  368. m_copy_stat_list{};
  369. mutable MTLCaptureManager* captureManager = nil;
  370. unsigned captureCount = 0;
  371. void initialize();
  372. void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
  373. void endProfileExecution(
  374. BaseInfo& info,
  375. os_signpost_id_t event_signpost_id,
  376. os_signpost_id_t interval_signpost_id,
  377. double gpuTime,
  378. double schedulingTime);
  379. void addProfilerScheduledHandler(BaseInfo& info);
  380. void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
  381. void emitSignpostEvent(
  382. SignpostTypes signpost_type,
  383. os_signpost_id_t signpost_id,
  384. const std::string& msg) const;
  385. void beginSignpostInterval(
  386. SignpostTypes signpost_type,
  387. os_signpost_id_t signpost_id,
  388. const std::string& msg) const;
  389. void endSignpostInterval(
  390. SignpostTypes signpost_type,
  391. os_signpost_id_t signpost_id) const;
  392. void updateCopyStats(
  393. const CopyInfo& copyInfo,
  394. double gpuTime,
  395. double schedulingTime);
  396. // returns true if logging the profiling info "during the execution" is
  397. // enabled
  398. bool isProfileInfoLoggingEnabled(
  399. BaseInfo::Type infoType,
  400. bool isExecutionEnded);
  401. // logs all the profiling stats that are enabled
  402. void logProfilingStats();
  403. // logs kernel profiling stats when the process ends.
  404. void logOperationsProfilingStats(std::FILE* f) const;
  405. // logs CPU Fallback profiling stats when the process ends.
  406. void logCPUFallbackProfilingStats(std::FILE* f) const;
  407. // logs copy profiling stats when the process ends.
  408. void logCopyProfilingStats(std::FILE* f) const;
  409. os_signpost_id_t generateSignpostId(
  410. os_signpost_type_t signpostType,
  411. const void* ptr = nullptr);
  412. static SignpostTypes getSignpostType(BaseInfo::Type infoType);
  413. static void handleIntSignal(int signal);
  414. };
  415. } // namespace Profiler
  // Returns a reference to an MPSProfiler; only declared here, the definition
  // lives outside this header.
  // NOTE(review): presumably the single process-wide profiler instance --
  // confirm in the implementation file.
  416. Profiler::MPSProfiler& getMPSProfiler();
  417. } // namespace at::mps
  418. #else
  419. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  420. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)