_gpu_trace.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from collections.abc import Callable
  2. from torch._utils import CallbackRegistry
  3. EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry("XPU event creation")
  4. EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry("XPU event deletion")
  5. EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
  6. "XPU event record"
  7. )
  8. EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry("XPU event wait")
  9. MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  10. "XPU memory allocation"
  11. )
  12. MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  13. "XPU memory deallocation"
  14. )
  15. StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  16. "XPU stream creation"
  17. )
  18. DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
  19. "XPU device synchronization"
  20. )
  21. StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  22. "XPU stream synchronization"
  23. )
  24. EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  25. "XPU event synchronization"
  26. )
  27. def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
  28. EventCreationCallbacks.add_callback(cb)
  29. def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
  30. EventDeletionCallbacks.add_callback(cb)
  31. def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
  32. EventRecordCallbacks.add_callback(cb)
  33. def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
  34. EventWaitCallbacks.add_callback(cb)
  35. def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
  36. MemoryAllocationCallbacks.add_callback(cb)
  37. def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
  38. MemoryDeallocationCallbacks.add_callback(cb)
  39. def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
  40. StreamCreationCallbacks.add_callback(cb)
  41. def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
  42. DeviceSynchronizationCallbacks.add_callback(cb)
  43. def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
  44. StreamSynchronizationCallbacks.add_callback(cb)
  45. def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
  46. EventSynchronizationCallbacks.add_callback(cb)