| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # mypy: allow-untyped-defs
- """Registry for flash attention implementations.
- This module contains the registration system for flash attention implementations.
- It has no torch dependencies to avoid circular imports during initialization.
- """
- import logging
- from collections.abc import Callable
- from typing import Literal, Protocol
- logger = logging.getLogger(__name__)
- class FlashAttentionHandle(Protocol):
- def remove(self) -> None: ...
- _RegisterFn = Callable[..., FlashAttentionHandle | None]
- _FlashAttentionImpl = Literal["FA3", "FA4"]
- _FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
- _FLASH_ATTENTION_ACTIVE: tuple[str, FlashAttentionHandle] | None = None
- def register_flash_attention_impl(
- impl: str | _FlashAttentionImpl,
- *,
- register_fn: _RegisterFn,
- ) -> None:
- """
- Register the callable that activates a flash attention impl.
- .. note::
- This function is intended for SDPA backend providers to register their
- implementations. End users should use :func:`activate_flash_attention_impl`
- to activate a registered implementation.
- Args:
- impl: Implementation identifier (e.g., ``"FA4"``).
- register_fn: Callable that performs the actual dispatcher registration.
- This function will be invoked by :func:`activate_flash_attention_impl`
- and should register custom kernels with the PyTorch dispatcher.
- It may optionally return a handle implementing
- :class:`FlashAttentionHandle` to keep any necessary state alive.
- Example:
- >>> def my_impl_register(module_path: str = "my_flash_impl"):
- ... # Register custom kernels with torch dispatcher
- ... pass # doctest: +SKIP
- >>> register_flash_attention_impl(
- ... "MyImpl", register_fn=my_impl_register
- ... ) # doctest: +SKIP
- """
- global _FLASH_ATTENTION_IMPLS
- _FLASH_ATTENTION_IMPLS[impl] = register_fn
- def activate_flash_attention_impl(
- impl: str | _FlashAttentionImpl,
- ) -> None:
- """
- Activate into the dispatcher a previously registered flash attention impl.
- .. note::
- Backend providers should NOT automatically activate their implementation
- on import. Users should explicitly opt-in by calling this function or via
- environment variables to ensure multiple provider libraries can coexist.
- Args:
- impl: Implementation identifier to activate. See
- :func:`~torch.nn.attention.list_flash_attention_impls` for available
- implementations.
- If the backend's :func:`register_flash_attention_impl` callable
- returns a :class:`FlashAttentionHandle`, the registry keeps that
- handle alive for the lifetime of the process (until explicit
- uninstall support exists).
- Example:
- >>> activate_flash_attention_impl("FA4") # doctest: +SKIP
- """
- global _FLASH_ATTENTION_ACTIVE, _FLASH_ATTENTION_IMPLS
- restore_flash_attention_impl(
- _raise_warn=False
- ) # first restore any prev overrides (if any) to default
- register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
- if register_fn is None:
- raise ValueError(
- f"Unknown flash attention impl '{impl}'. "
- f"Available implementations: {list_flash_attention_impls()}"
- )
- handle = register_fn()
- if handle is not None:
- _FLASH_ATTENTION_ACTIVE = (impl, handle)
- def list_flash_attention_impls() -> list[str]:
- """Return the names of all available flash attention implementations."""
- return sorted(_FLASH_ATTENTION_IMPLS.keys())
- def current_flash_attention_impl() -> str | None:
- """
- Return the currently activated flash attention impl name, if any.
- ``None`` indicates that no custom impl has been activated.
- """
- return (
- _FLASH_ATTENTION_ACTIVE[0]
- if _FLASH_ATTENTION_ACTIVE is not None
- else _FLASH_ATTENTION_ACTIVE
- )
- def restore_flash_attention_impl(_raise_warn: bool = True) -> None:
- """
- Restore the default FA2 implementation
- """
- global _FLASH_ATTENTION_ACTIVE
- handle = None
- if _FLASH_ATTENTION_ACTIVE is not None:
- handle = _FLASH_ATTENTION_ACTIVE[1]
- if handle is not None:
- handle.remove()
- elif _raise_warn:
- logger.warning(
- "Trying to restore default FA2 impl when no custom impl was activated"
- )
- _FLASH_ATTENTION_ACTIVE = None # default
|