_allocation.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from typing import Optional, Protocol
  2. from contextvars import ContextVar
  3. class Buffer(Protocol):
  4. def data_ptr(self) -> int:
  5. ...
  6. class Allocator(Protocol):
  7. def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
  8. ...
  9. class NullAllocator:
  10. def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
  11. raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " +
  12. "Use triton.set_allocator to specify an allocator.")
  13. _NULL_ALLOCATOR = NullAllocator()
  14. _allocator: ContextVar[Allocator] = ContextVar("_allocator", default=_NULL_ALLOCATOR)
  15. def set_allocator(allocator: Allocator) -> None:
  16. """
  17. The allocator function is called during kernel launch for kernels that
  18. require additional global memory workspace.
  19. """
  20. _allocator.set(allocator)
  21. class _AllocatorWrapper:
  22. """
  23. Wrapper to provide ContextVar-like .get()/.set() methods. profile_allocator is
  24. used in same way as allocator so it is useful to maintain the interface.
  25. """
  26. def __init__(self, allocator: Allocator) -> None:
  27. self._allocator = allocator
  28. def get(self) -> Allocator:
  29. return self._allocator
  30. def set(self, allocator: Allocator) -> None:
  31. self._allocator = allocator
  32. def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
  33. return self._allocator(size, alignment, stream)
  34. _profile_allocator = _AllocatorWrapper(_NULL_ALLOCATOR)
  35. def set_profile_allocator(allocator: Optional[Allocator]) -> None:
  36. """
  37. The profile allocator function is called before kernel launch for kernels
  38. that require additional global memory workspace.
  39. """
  40. _profile_allocator.set(allocator if allocator is not None else _NULL_ALLOCATOR)