| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- from collections.abc import Callable
- from torch._utils import CallbackRegistry
- EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry("XPU event creation")
- EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry("XPU event deletion")
- EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
- "XPU event record"
- )
- EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry("XPU event wait")
- MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
- "XPU memory allocation"
- )
- MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
- "XPU memory deallocation"
- )
- StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
- "XPU stream creation"
- )
- DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
- "XPU device synchronization"
- )
- StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
- "XPU stream synchronization"
- )
- EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
- "XPU event synchronization"
- )
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
    """Subscribe *cb* to the "XPU event creation" registry.

    Args:
        cb: Added to ``EventCreationCallbacks`` and called with one ``int``
            when that registry fires. NOTE(review): presumably the event
            handle; confirm at the fire site, which is not visible here.
    """
    EventCreationCallbacks.add_callback(cb)
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
    """Subscribe *cb* to the "XPU event deletion" registry.

    Args:
        cb: Added to ``EventDeletionCallbacks`` and called with one ``int``
            when that registry fires. NOTE(review): presumably the event
            handle; confirm at the fire site, which is not visible here.
    """
    EventDeletionCallbacks.add_callback(cb)
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
    """Subscribe *cb* to the "XPU event record" registry.

    Args:
        cb: Added to ``EventRecordCallbacks`` and called with two ``int``
            arguments when that registry fires. NOTE(review): presumably
            (event, stream); confirm at the fire site, which is not visible
            here.
    """
    EventRecordCallbacks.add_callback(cb)
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
    """Subscribe *cb* to the "XPU event wait" registry.

    Args:
        cb: Added to ``EventWaitCallbacks`` and called with two ``int``
            arguments when that registry fires. NOTE(review): presumably
            (event, stream); confirm at the fire site, which is not visible
            here.
    """
    EventWaitCallbacks.add_callback(cb)
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
    """Subscribe *cb* to the "XPU memory allocation" registry.

    Args:
        cb: Added to ``MemoryAllocationCallbacks`` and called with one
            ``int`` when that registry fires. NOTE(review): presumably the
            allocation address; confirm at the fire site, which is not
            visible here.
    """
    MemoryAllocationCallbacks.add_callback(cb)
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
    """Subscribe *cb* to the "XPU memory deallocation" registry.

    Args:
        cb: Added to ``MemoryDeallocationCallbacks`` and called with one
            ``int`` when that registry fires. NOTE(review): presumably the
            freed allocation's address; confirm at the fire site, which is
            not visible here.
    """
    MemoryDeallocationCallbacks.add_callback(cb)
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
    """Subscribe *cb* to the "XPU stream creation" registry.

    Args:
        cb: Added to ``StreamCreationCallbacks`` and called with one ``int``
            when that registry fires. NOTE(review): presumably the stream
            handle; confirm at the fire site, which is not visible here.
    """
    StreamCreationCallbacks.add_callback(cb)
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
    """Subscribe *cb* to the "XPU device synchronization" registry.

    Args:
        cb: Added to ``DeviceSynchronizationCallbacks`` and called with no
            arguments when that registry fires.
    """
    DeviceSynchronizationCallbacks.add_callback(cb)
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
    """Subscribe *cb* to the "XPU stream synchronization" registry.

    Args:
        cb: Added to ``StreamSynchronizationCallbacks`` and called with one
            ``int`` when that registry fires. NOTE(review): presumably the
            stream handle; confirm at the fire site, which is not visible
            here.
    """
    StreamSynchronizationCallbacks.add_callback(cb)
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
    """Subscribe *cb* to the "XPU event synchronization" registry.

    Args:
        cb: Added to ``EventSynchronizationCallbacks`` and called with one
            ``int`` when that registry fires. NOTE(review): presumably the
            event handle; confirm at the fire site, which is not visible
            here.
    """
    EventSynchronizationCallbacks.add_callback(cb)
|