MPSEvent.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright © 2023 Apple Inc.
  3. #pragma once
  4. #include <ATen/mps/MPSStream.h>
  5. #include <ctime>
  6. #include <stack>
  7. namespace at::mps {
  8. // NOTE: don't create instances of this class directly.
  9. // Use MPSEventPool to acquire instances of MPSEvent.
  10. class MPSEvent {
  11. public:
  12. explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
  13. ~MPSEvent();
  14. // records an event on the stream
  15. void record(bool needsLock, bool syncEvent = false);
  16. // makes all future work submitted to the stream wait for this event.
  17. bool wait(bool needsLock, bool syncEvent = false);
  18. // schedules a notifyListener callback for the event.
  19. bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
  20. // checks if events are already signaled.
  21. bool query() const;
  22. // blocks the CPU thread until all the GPU work that were scheduled
  23. // prior to recording this event are completed.
  24. bool synchronize();
  25. // resets this event with new parameters in case it gets reused from the event
  26. // pool
  27. void reset(MPSStream* stream, bool enable_timing);
  28. // returns the unique ID of the event instance
  29. id_t getID() const {
  30. return m_id;
  31. }
  32. // returns the completion timestamp of the event
  33. uint64_t getCompletionTime() const {
  34. return m_completion_time;
  35. }
  36. // if already recorded, waits for cpu_sync_cv to be signaled
  37. void waitForCpuSync();
  38. private:
  39. id_t m_id;
  40. // enables measuring the completion time of the notifyListener of this event
  41. bool m_enable_timing;
  42. uint64_t m_signalCounter = 0;
  43. MPSStream* m_stream = nullptr;
  44. MTLSharedEvent_t m_event = nullptr;
  45. MTLSharedEventListener* m_listener = nullptr;
  46. // used to sync the events created on this Stream with CPU
  47. std::mutex m_cpu_sync_mutex{};
  48. std::condition_variable m_cpu_sync_cv{};
  49. // CondVar predicate to sync the events created on this Stream with CPU
  50. bool m_cpu_sync_completed = false;
  51. // used to compute elapsed time
  52. uint64_t m_completion_time = 0;
  53. void recordLocked(bool syncEvent);
  54. bool waitLocked(bool syncEvent);
  55. bool notifyLocked(MTLSharedEventNotificationBlock block);
  56. void notifyCpuSync();
  57. static uint64_t getTime() {
  58. return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
  59. }
  60. };
  61. typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
  62. class MPSEventPool {
  63. public:
  64. explicit MPSEventPool(MPSStream* default_stream);
  65. ~MPSEventPool();
  66. MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
  67. void emptyCache();
  68. // these are mainly used for MPSHooks and torch.mps.Event() bindings
  69. id_t acquireEvent(bool enable_timing);
  70. void releaseEvent(id_t event_id);
  71. void recordEvent(id_t event_id, bool syncEvent);
  72. void waitForEvent(id_t event_id, bool syncEvent);
  73. void synchronizeEvent(id_t event_id);
  74. bool queryEvent(id_t event_id);
  75. // returns elapsed time between two recorded events in milliseconds
  76. double elapsedTime(id_t start_event_id, id_t end_event_id);
  77. private:
  78. MPSStream* m_default_stream = nullptr;
  79. std::recursive_mutex m_mutex;
  80. std::stack<std::unique_ptr<MPSEvent>> m_pool{};
  81. // dictionary to associate event IDs with event objects
  82. // used to retain in-use events out of the pool
  83. // for torch.mps.Event() bindings.
  84. std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
  85. uint64_t m_event_counter = 0;
  86. std::function<void(MPSEvent*)> m_default_deleter;
  87. MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
  88. };
  89. // shared_ptr is used to get MPSEventPool destroyed after dependent instances
  90. std::shared_ptr<MPSEventPool> getMPSEventPool();
  91. } // namespace at::mps
  92. #else
  93. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  94. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)