| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import importlib.metadata
- import os
- import re
- from collections.abc import Callable
- from contextlib import contextmanager
- from types import ModuleType
- from packaging import version as pkg_version
- from ..utils import ENV_VARS_TRUE_VALUES, logging
- from ..utils.import_utils import is_kernels_available
- from .flash_attention import flash_attention_forward
- logger = logging.get_logger(__name__)
- try:
- from kernels import (
- Device,
- LayerRepository,
- Mode,
- register_kernel_mapping,
- replace_kernel_forward_from_hub,
- )
- from kernels import (
- get_kernel as get_kernel_hub,
- )
- from kernels import (
- use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
- )
- # Try to import FuncRepository, fallback if not available
- try:
- from kernels import FuncRepository
- except ImportError:
- FuncRepository = None
- # Try to import use_kernel_func_from_hub, fallback if not available
- try:
- from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub
- _has_use_kernel_func_from_hub = True
- except ImportError:
- _has_use_kernel_func_from_hub = False
# Raw value of the `USE_HUB_KERNELS` environment variable, upper-cased; defaults to "YES" when unset.
_TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
# Reaching this line means the `kernels` imports above succeeded.
_kernels_available = True
# Hub kernels are only enabled when the environment value is one of `ENV_VARS_TRUE_VALUES`.
_kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES
def use_kernel_forward_from_hub(layer_name: str):
    """
    Return the class decorator that makes a layer's forward replaceable by a hub kernel.

    When hub kernels are enabled, this defers to the `kernels` implementation for `layer_name`.
    Otherwise a warning is logged once and an identity decorator is returned, so the decorated
    class is left untouched.
    """
    if not _kernels_enabled:
        logger.warning_once(
            f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
        )

        def passthrough(cls):
            return cls

        return passthrough

    return _kernels_use_kernel_forward_from_hub(layer_name)
def use_kernel_func_from_hub(func_name: str):
    """
    Return the decorator that makes a plain function replaceable by a hub kernel.

    The `kernels` implementation is used only when hub kernels are enabled and the installed
    `kernels` exports `use_kernel_func_from_hub`. In every other case the reason is logged once
    and an identity decorator is returned.
    """
    if _kernels_enabled and _has_use_kernel_func_from_hub:
        return _kernels_use_kernel_func_from_hub(func_name)

    # A missing feature takes precedence over the "disabled by environment" explanation.
    if _has_use_kernel_func_from_hub:
        logger.warning_once(
            f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
        )
    else:
        logger.warning_once(
            "use_kernel_func_from_hub is not available in the installed kernels version. "
            "Please upgrade kernels to use this feature."
        )

    def passthrough(func):
        return func

    return passthrough
# Default kernel mapping, used by `register_kernel_mapping_transformers` when no mapping is passed.
# Structure: layer name -> device key -> either a single `LayerRepository` or a dict keyed by
# `Mode` (modes can be combined with `|`) whose values are `LayerRepository` objects.
_KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository | dict[Mode, LayerRepository]]] = {
    "MultiScaleDeformableAttention": {
        "cuda": LayerRepository(
            repo_id="kernels-community/deformable-detr",
            layer_name="MultiScaleDeformableAttention",
        )
    },
    "Llama4TextMoe": {
        "cuda": LayerRepository(
            repo_id="kernels-community/moe",
            layer_name="Llama4TextMoe",
        )
    },
    # RMSNorm is registered for inference only, with a per-device repository.
    "RMSNorm": {
        "cuda": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
                # revision="pure-layer-test",
            ),
        },
        "rocm": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        },
        "xpu": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/rmsnorm",
                layer_name="RMSNorm",
            )
        },
        "mps": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/mlx_rmsnorm",
                layer_name="RMSNorm",
            )
        },
        "npu": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        },
    },
    "MLP": {
        "cuda": LayerRepository(
            repo_id="medmekk/triton-llama-mlp",
            layer_name="TritonLlamaMLP",
        )
    },
    # Only the "cuda" entry registers a training mode; the other devices are inference only.
    "MegaBlocksMoeMLP": {
        "cuda": {
            Mode.TRAINING: LayerRepository(
                repo_id="kernels-community/megablocks",
                layer_name="MegaBlocksMoeMLP",
            ),
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/megablocks",
                layer_name="MegaBlocksMoeMLP",
            ),
        },
        "rocm": {
            Mode.INFERENCE: LayerRepository(
                repo_id="ahadnagy/megablocks",
                layer_name="MegaBlocksMoeMLP",
            )
        },
        "xpu": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/megablocks",
                layer_name="MegaBlocksMoeMLP",
            )
        },
        "cpu": {
            Mode.INFERENCE: LayerRepository(
                repo_id="kernels-community/megablocks",
                layer_name="CPUMegaBlocksMoeMLP",
            )
        },
    },
    # Activation layers: all come from `kernels-community/activation`, pinned to version 1,
    # and are registered under the combined mode `Mode.INFERENCE | Mode.TORCH_COMPILE`.
    "FastGELU": {
        "cuda": {
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="FastGELU",
                version=1,
            )
        }
    },
    "QuickGELU": {
        "cuda": {
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="QuickGELU",
                version=1,
            )
        }
    },
    "NewGELU": {
        "cuda": {
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation",
                layer_name="NewGELU",
                version=1,
            )
        }
    },
    # Note: the mapping key ("SiLU") and the hub layer name ("Silu") differ in casing.
    "SiLU": {
        "cuda": {
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation", layer_name="Silu", version=1
            )
        }
    },
    # Note: the mapping key ("GeLU") and the hub layer name ("Gelu") differ in casing.
    "GeLU": {
        "cuda": {
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation", layer_name="Gelu", version=1
            )
        }
    },
    "GeluTanh": {
        "cuda": {
            Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
                repo_id="kernels-community/activation", layer_name="GeluTanh", version=1
            )
        }
    },
}
# Add function kernel mappings if FuncRepository is available.
# `FuncRepository` is set to None above when the installed `kernels` does not export it,
# in which case no "rotary_pos_emb" entry is registered.
if FuncRepository is not None:
    _KERNEL_MAPPING["rotary_pos_emb"] = {
        "xpu": {
            Mode.INFERENCE: FuncRepository(
                repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
            )
        },
        "cuda": {
            Mode.INFERENCE: FuncRepository(
                repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
            )
        },
    }
def has_key(d, key):
    """
    Return True if `key` is a key of `d` or of any dict nested (at any depth) in its values.

    Only values that are themselves dicts are searched; dicts hidden inside other containers
    (lists, tuples, ...) are not inspected.
    """
    if key in d:
        return True
    for value in d.values():
        if isinstance(value, dict) and has_key(value, key):
            return True
    return False
def register_kernel_mapping_transformers(mapping=None):
    """
    Register a kernel mapping with `kernels`, defaulting to `_KERNEL_MAPPING`.

    Args:
        mapping: The mapping to register. When `None`, the module-level `_KERNEL_MAPPING` is used.

    Raises:
        ImportError: If the mapping contains an "xpu" key at any nesting level and
            `is_kernels_available(MIN_VERSION="0.10.2")` is false.
    """
    kernel_mapping = _KERNEL_MAPPING if mapping is None else mapping

    # The version check only runs for mappings that actually reference "xpu".
    if has_key(kernel_mapping, "xpu"):
        if not is_kernels_available(MIN_VERSION="0.10.2"):
            raise ImportError(
                "kernels uses an incompatible version. Please install the latest version with `pip install -U kernels`."
            )

    register_kernel_mapping(kernel_mapping)
- except ImportError:
# An ImportError was raised while importing from `kernels`: mark hub kernels as both
# unavailable and disabled.
_kernels_available = False
_kernels_enabled = False
# Stubs to make the decorators in transformers work when `kernels`
# is not installed.
def use_kernel_forward_from_hub(*args, **kwargs):
    """No-op stand-in: ignore all arguments and return a decorator that returns the class unchanged."""
    return lambda cls: cls
def use_kernel_func_from_hub(*args, **kwargs):
    """No-op stand-in: ignore all arguments and return a decorator that returns the function unchanged."""
    return lambda func: func
class LayerRepository:
    """Placeholder used when `kernels` cannot be imported; instantiating it always raises."""

    def __init__(self, *args, **kwargs):
        message = "LayerRepository requires `kernels` to be installed. Run `pip install kernels`."
        raise RuntimeError(message)
def replace_kernel_forward_from_hub(*args, **kwargs):
    """Placeholder used when `kernels` cannot be imported; every call raises `RuntimeError`."""
    message = "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
    raise RuntimeError(message)
def register_kernel_mapping(*args, **kwargs):
    """Placeholder used when `kernels` cannot be imported; every call raises `RuntimeError`."""
    message = "register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`."
    raise RuntimeError(message)
def register_kernel_mapping_transformers(*args, **kwargs):
    """Placeholder used when `kernels` cannot be imported; every call raises `RuntimeError`."""
    message = "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
    raise RuntimeError(message)
# Kernels that `lazy_load_kernel` knows how to fetch from the hub, keyed by kernel name.
# Each entry provides a "repo_id" and may provide a "version" and/or a "revision"
# (both are read with `.get(...)`, so they are optional). Versions are stored as ints.
_HUB_KERNEL_MAPPING: dict[str, dict[str, str | int]] = {
    "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d", "version": 1},
    "mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
    "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
    "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1},
    "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1},
}

# Default cache for `lazy_load_kernel`: kernel name -> loaded module, or None when loading failed.
_KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {}
- def is_kernel(attn_implementation: str | None) -> bool:
- """Check whether `attn_implementation` matches a kernel pattern from the hub."""
- return (
- attn_implementation is not None
- and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
- )
def load_and_register_attn_kernel(
    attn_implementation: str, attention_wrapper: Callable | None = None, allow_all_kernels: bool = False
) -> ModuleType | None:
    """
    Load and register the kernel associated to `attn_implementation`.

    Args:
        attn_implementation (`str`):
            A string, usually a kernel repo like "kernels-community/flash-mla", optionally followed by
            `@revision` and/or `:function_name`. If it contains `|`, the second `|`-separated segment is
            used to locate the kernel, while the full string is used as the registration key.
        attention_wrapper (`Callable`, *optional*):
            A callable for the wrapper around the attention implementation. In `transformers` we
            have a wrapper around the `flash_attn_var_len` call, and the same goes for `sdpa` and `eager`.
            They just prepare the arguments properly. This is mostly used for continuous batching, where we
            want the `paged` wrapper, which calls the paged cache. Only used when the loaded kernel exposes
            `flash_attn_varlen_func`; defaults to `flash_attention_forward` in that case.
        allow_all_kernels (`bool`, *optional*, defaults to `False`):
            Whether to load kernels from unverified hub repos, if it is a custom kernel outside of the
            `kernels-community` hub repository.

    Returns:
        The loaded kernel module, or `None` when `attn_implementation` does not match the hub kernel pattern.

    Raises:
        ImportError: If `kernels` is not available.
        ValueError: If the kernel cannot be loaded, or if it neither exposes `flash_attn_varlen_func` nor
            was requested with an explicit `:function_name`.
    """
    from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
    from ..modeling_utils import ALL_ATTENTION_FUNCTIONS

    actual_attn_name = attn_implementation.split("|")[1] if "|" in attn_implementation else attn_implementation
    if not is_kernel(actual_attn_name):
        return None
    if not _kernels_available:
        raise ImportError(
            "`kernels` is either not installed or uses an incompatible version. "
            "Please install the latest version with `pip install -U kernels`."
        )

    # Extract repo_id and kernel_name from the string
    if ":" in actual_attn_name:
        repo_id, kernel_name = actual_attn_name.split(":")
        kernel_name = kernel_name.strip()
    else:
        repo_id = actual_attn_name
        kernel_name = None
    repo_id = repo_id.strip()

    # extract the rev after the @ if it exists
    repo_id, _, rev = repo_id.partition("@")
    repo_id = repo_id.strip()
    rev = rev.strip() if rev else None

    # Load the kernel from hub
    try:
        kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=allow_all_kernels)
    except Exception as e:
        # Chain the original exception so the root cause stays visible in the traceback.
        raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.") from e

    # correctly wrap the kernel
    if hasattr(kernel, "flash_attn_varlen_func"):
        if attention_wrapper is None:
            attention_wrapper = flash_attention_forward
        kernel_function = attention_wrapper
    elif kernel_name is not None:
        kernel_function = getattr(kernel, kernel_name)
    else:
        # Without this branch `kernel_function` would be unbound below and the caller would get an
        # opaque `UnboundLocalError` instead of an actionable message.
        raise ValueError(
            f"The kernel loaded from '{repo_id}' does not expose `flash_attn_varlen_func` and no function name "
            "was provided. Specify the function to use with the `repo_id:function_name` syntax."
        )

    # Register the kernel as a valid attention
    ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
    ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
    return kernel
def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _KERNEL_MODULE_MAPPING):
    """
    Return the module for `kernel_name`, loading it on first use and caching the outcome in `mapping`.

    Resolution order:
      1. A module already cached in `mapping` is returned as is.
      2. Names missing from `_HUB_KERNEL_MAPPING` log a warning once and resolve to `None`.
      3. With `kernels` available, the kernel is fetched through `get_kernel`; a `FileNotFoundError`
         or `AssertionError` resolves to `None`.
      4. Without `kernels`, a locally installed package named after the kernel (dashes replaced by
         underscores) is imported, provided `..utils.import_utils` exposes a truthy
         `is_<name>_available()` check.

    Args:
        kernel_name: Key of the kernel in `_HUB_KERNEL_MAPPING`.
        mapping: Cache to read from and write to. Defaults to the shared module-level cache.

    Returns:
        The loaded module, or `None` when the kernel could not be resolved. A cached `None` is not
        treated as final: resolution is attempted again on the next call.
    """
    cached = mapping.get(kernel_name)
    if isinstance(cached, ModuleType):
        return cached

    if kernel_name not in _HUB_KERNEL_MAPPING:
        logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
        mapping[kernel_name] = None
        return None

    if not _kernels_available:
        # Try to import is_{kernel_name}_available from ..utils
        import importlib

        module_name = kernel_name.replace("-", "_")
        try:
            import_utils = importlib.import_module("..utils.import_utils", __package__)
            availability_check = getattr(import_utils, f"is_{module_name}_available", None)
        except Exception:
            availability_check = None

        local_module = None
        if callable(availability_check) and availability_check():
            # The availability check passed: import the locally installed package itself.
            try:
                local_module = importlib.import_module(module_name)
            except Exception:
                local_module = None

        mapping[kernel_name] = local_module
        return local_module

    hub_entry = _HUB_KERNEL_MAPPING[kernel_name]
    try:
        mapping[kernel_name] = get_kernel(
            hub_entry["repo_id"],
            revision=hub_entry.get("revision", None),
            version=hub_entry.get("version", None),
        )
    except (FileNotFoundError, AssertionError):
        # AssertionError happens when torch is built without an accelerator backend; fall back to slow path.
        mapping[kernel_name] = None

    return mapping[kernel_name]
def get_kernel(
    kernel_name: str,
    revision: str | None = None,
    version: int | str | None = None,
    allow_all_kernels: bool = False,
) -> ModuleType:
    """
    Fetch a kernel from the hub through `kernels`, enforcing the trusted-repository policy.

    Args:
        kernel_name: Hub repository id of the kernel, e.g. "kernels-community/mamba-ssm".
        revision: Optional revision forwarded to `kernels.get_kernel`.
        version: Optional version forwarded to `kernels.get_kernel`.
        allow_all_kernels: Allow repositories whose owner is not `kernels-community`.

    Returns:
        The module returned by `kernels.get_kernel`.

    Raises:
        ImportError: If `kernels` is not available.
        ValueError: If the repository owner is not `kernels-community` and `allow_all_kernels` is false.
    """
    from .. import __version__

    if not _kernels_available:
        raise ImportError(
            "`kernels` is either not installed or uses an incompatible version. Please install the latest version "
            "with `pip install -U kernels`."
        )

    # all `kernels-community` repos are trusted by default!
    owner = kernel_name.partition("/")[0]
    is_trusted = owner == "kernels-community"
    if not (is_trusted or allow_all_kernels):
        raise ValueError(
            "You need to specify `allow_all_kernels=True` to use kernels outside of the `kernels-community` repository"
        )

    hub_kwargs = {"revision": revision, "version": version}
    # `user_agent` is only passed when the installed `kernels` is at least 0.10.4.
    installed_kernels = pkg_version.parse(importlib.metadata.version("kernels"))
    if installed_kernels >= pkg_version.parse("0.10.4"):
        hub_kwargs["user_agent"] = {"framework": "transformers", "version": __version__, "repo_id": kernel_name}

    return get_kernel_hub(kernel_name, **hub_kwargs)
- def use_kernelized_func(module_names: list[Callable] | Callable):
- """
- This decorator attaches the target function as an attribute of the module.
- The function must already be decorated with @use_kernel_func_from_hub
- this decorator then wraps it as an nn.Module internally.
- When kernelize is later applied to the full model, the function can be accessed as a regular module attribute and kernelized just like any other layer.
- The kernelization is performed in place, modifying the module directly.
- """
- if isinstance(module_names, Callable):
- module_names = [module_names]
- def decorator(cls):
- orig_init = cls.__init__
- def new_init(self, *args, **kwargs):
- orig_init(self, *args, **kwargs)
- # Skip attaching the kernelized submodule under DeepSpeed ZeRO-3: the coordinator traces
- # the module graph at init time, and a child `nn.Module` that is not actually invoked
- # during forward (e.g. when the model keeps calling the plain Python `apply_rotary_pos_emb`)
- # breaks the parameter fetch trace and raises `IndexError: pop from an empty deque`.
- # See https://github.com/huggingface/transformers/issues/45137
- from .deepspeed import is_deepspeed_zero3_enabled
- if is_deepspeed_zero3_enabled():
- return
- for fn in module_names:
- # we hardcode the name of the function to "rotary_fn" for now
- setattr(self, "rotary_fn", fn)
- cls.__init__ = new_init
- return cls
- return decorator
# Whether to allow hub kernels coming from untrusted repos, i.e. repos outside `kernels-community`
ALLOW_ALL_KERNELS = False


@contextmanager
def allow_all_hub_kernels():
    """
    Context manager used to adjust the value of the global `ALLOW_ALL_KERNELS`. This is needed, as this argument
    cannot be forwarded directly to the `__init__` of the models, where we set the attention implementation.

    The flag is set to `True` inside the `with` block and restored to the value it had on entry when the
    block exits, including when it exits with an exception. Restoring the previous value (rather than
    unconditionally resetting to `False`) keeps nested uses correct: leaving an inner context does not
    disable the flag for the enclosing one.
    """
    global ALLOW_ALL_KERNELS
    previous_value = ALLOW_ALL_KERNELS
    ALLOW_ALL_KERNELS = True
    try:
        yield
    finally:
        # Set back the original
        ALLOW_ALL_KERNELS = previous_value
# Public API of this module. When `kernels` cannot be imported, `LayerRepository`, the two
# `use_kernel_*_from_hub` decorators, `register_kernel_mapping`, `register_kernel_mapping_transformers`
# and `replace_kernel_forward_from_hub` resolve to the stubs defined in the `except ImportError` branch.
__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "use_kernel_func_from_hub",
    "register_kernel_mapping",
    "register_kernel_mapping_transformers",
    "replace_kernel_forward_from_hub",
    "lazy_load_kernel",
    "get_kernel",
    "use_kernelized_func",
]  # type: ignore
|