_gpu_trace.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from collections.abc import Callable
  2. from torch._utils import CallbackRegistry
  3. EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  4. "CUDA event creation"
  5. )
  6. EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  7. "CUDA event deletion"
  8. )
  9. EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
  10. "CUDA event record"
  11. )
  12. EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry("CUDA event wait")
  13. MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  14. "CUDA memory allocation"
  15. )
  16. MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  17. "CUDA memory deallocation"
  18. )
  19. StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  20. "CUDA stream creation"
  21. )
  22. DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
  23. "CUDA device synchronization"
  24. )
  25. StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  26. "CUDA stream synchronization"
  27. )
  28. EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  29. "CUDA event synchronization"
  30. )
  31. def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
  32. EventCreationCallbacks.add_callback(cb)
  33. def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
  34. EventDeletionCallbacks.add_callback(cb)
  35. def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
  36. EventRecordCallbacks.add_callback(cb)
  37. def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
  38. EventWaitCallbacks.add_callback(cb)
  39. def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
  40. MemoryAllocationCallbacks.add_callback(cb)
  41. def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
  42. MemoryDeallocationCallbacks.add_callback(cb)
  43. def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
  44. StreamCreationCallbacks.add_callback(cb)
  45. def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
  46. DeviceSynchronizationCallbacks.add_callback(cb)
  47. def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
  48. StreamSynchronizationCallbacks.add_callback(cb)
  49. def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
  50. EventSynchronizationCallbacks.add_callback(cb)