| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from typing import Optional, Protocol
- from contextvars import ContextVar
- class Buffer(Protocol):
- def data_ptr(self) -> int:
- ...
- class Allocator(Protocol):
- def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
- ...
- class NullAllocator:
- def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
- raise RuntimeError("Kernel requires a runtime memory allocation, but no allocator was set. " +
- "Use triton.set_allocator to specify an allocator.")
- _NULL_ALLOCATOR = NullAllocator()
- _allocator: ContextVar[Allocator] = ContextVar("_allocator", default=_NULL_ALLOCATOR)
- def set_allocator(allocator: Allocator) -> None:
- """
- The allocator function is called during kernel launch for kernels that
- require additional global memory workspace.
- """
- _allocator.set(allocator)
- class _AllocatorWrapper:
- """
- Wrapper to provide ContextVar-like .get()/.set() methods. profile_allocator is
- used in same way as allocator so it is useful to maintain the interface.
- """
- def __init__(self, allocator: Allocator) -> None:
- self._allocator = allocator
- def get(self) -> Allocator:
- return self._allocator
- def set(self, allocator: Allocator) -> None:
- self._allocator = allocator
- def __call__(self, size: int, alignment: int, stream: Optional[int]) -> Buffer:
- return self._allocator(size, alignment, stream)
- _profile_allocator = _AllocatorWrapper(_NULL_ALLOCATOR)
- def set_profile_allocator(allocator: Optional[Allocator]) -> None:
- """
- The profile allocator function is called before kernel launch for kernels
- that require additional global memory workspace.
- """
- _profile_allocator.set(allocator if allocator is not None else _NULL_ALLOCATOR)
|