_registry.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # mypy: allow-untyped-defs
  2. """Registry for flash attention implementations.
  3. This module contains the registration system for flash attention implementations.
  4. It has no torch dependencies to avoid circular imports during initialization.
  5. """
  6. import logging
  7. from collections.abc import Callable
  8. from typing import Literal, Protocol
  9. logger = logging.getLogger(__name__)
  10. class FlashAttentionHandle(Protocol):
  11. def remove(self) -> None: ...
  12. _RegisterFn = Callable[..., FlashAttentionHandle | None]
  13. _FlashAttentionImpl = Literal["FA3", "FA4"]
  14. _FLASH_ATTENTION_IMPLS: dict[str, _RegisterFn] = {}
  15. _FLASH_ATTENTION_ACTIVE: tuple[str, FlashAttentionHandle] | None = None
  16. def register_flash_attention_impl(
  17. impl: str | _FlashAttentionImpl,
  18. *,
  19. register_fn: _RegisterFn,
  20. ) -> None:
  21. """
  22. Register the callable that activates a flash attention impl.
  23. .. note::
  24. This function is intended for SDPA backend providers to register their
  25. implementations. End users should use :func:`activate_flash_attention_impl`
  26. to activate a registered implementation.
  27. Args:
  28. impl: Implementation identifier (e.g., ``"FA4"``).
  29. register_fn: Callable that performs the actual dispatcher registration.
  30. This function will be invoked by :func:`activate_flash_attention_impl`
  31. and should register custom kernels with the PyTorch dispatcher.
  32. It may optionally return a handle implementing
  33. :class:`FlashAttentionHandle` to keep any necessary state alive.
  34. Example:
  35. >>> def my_impl_register(module_path: str = "my_flash_impl"):
  36. ... # Register custom kernels with torch dispatcher
  37. ... pass # doctest: +SKIP
  38. >>> register_flash_attention_impl(
  39. ... "MyImpl", register_fn=my_impl_register
  40. ... ) # doctest: +SKIP
  41. """
  42. global _FLASH_ATTENTION_IMPLS
  43. _FLASH_ATTENTION_IMPLS[impl] = register_fn
  44. def activate_flash_attention_impl(
  45. impl: str | _FlashAttentionImpl,
  46. ) -> None:
  47. """
  48. Activate into the dispatcher a previously registered flash attention impl.
  49. .. note::
  50. Backend providers should NOT automatically activate their implementation
  51. on import. Users should explicitly opt-in by calling this function or via
  52. environment variables to ensure multiple provider libraries can coexist.
  53. Args:
  54. impl: Implementation identifier to activate. See
  55. :func:`~torch.nn.attention.list_flash_attention_impls` for available
  56. implementations.
  57. If the backend's :func:`register_flash_attention_impl` callable
  58. returns a :class:`FlashAttentionHandle`, the registry keeps that
  59. handle alive for the lifetime of the process (until explicit
  60. uninstall support exists).
  61. Example:
  62. >>> activate_flash_attention_impl("FA4") # doctest: +SKIP
  63. """
  64. global _FLASH_ATTENTION_ACTIVE, _FLASH_ATTENTION_IMPLS
  65. restore_flash_attention_impl(
  66. _raise_warn=False
  67. ) # first restore any prev overrides (if any) to default
  68. register_fn = _FLASH_ATTENTION_IMPLS.get(impl)
  69. if register_fn is None:
  70. raise ValueError(
  71. f"Unknown flash attention impl '{impl}'. "
  72. f"Available implementations: {list_flash_attention_impls()}"
  73. )
  74. handle = register_fn()
  75. if handle is not None:
  76. _FLASH_ATTENTION_ACTIVE = (impl, handle)
  77. def list_flash_attention_impls() -> list[str]:
  78. """Return the names of all available flash attention implementations."""
  79. return sorted(_FLASH_ATTENTION_IMPLS.keys())
  80. def current_flash_attention_impl() -> str | None:
  81. """
  82. Return the currently activated flash attention impl name, if any.
  83. ``None`` indicates that no custom impl has been activated.
  84. """
  85. return (
  86. _FLASH_ATTENTION_ACTIVE[0]
  87. if _FLASH_ATTENTION_ACTIVE is not None
  88. else _FLASH_ATTENTION_ACTIVE
  89. )
  90. def restore_flash_attention_impl(_raise_warn: bool = True) -> None:
  91. """
  92. Restore the default FA2 implementation
  93. """
  94. global _FLASH_ATTENTION_ACTIVE
  95. handle = None
  96. if _FLASH_ATTENTION_ACTIVE is not None:
  97. handle = _FLASH_ATTENTION_ACTIVE[1]
  98. if handle is not None:
  99. handle.remove()
  100. elif _raise_warn:
  101. logger.warning(
  102. "Trying to restore default FA2 impl when no custom impl was activated"
  103. )
  104. _FLASH_ATTENTION_ACTIVE = None # default