hub_kernels.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib.metadata
  15. import os
  16. import re
  17. from collections.abc import Callable
  18. from contextlib import contextmanager
  19. from types import ModuleType
  20. from packaging import version as pkg_version
  21. from ..utils import ENV_VARS_TRUE_VALUES, logging
  22. from ..utils.import_utils import is_kernels_available
  23. from .flash_attention import flash_attention_forward
# Module-level logger obtained from transformers' own `utils.logging` (imported above), not the stdlib `logging`.
logger = logging.get_logger(__name__)
  25. try:
  26. from kernels import (
  27. Device,
  28. LayerRepository,
  29. Mode,
  30. register_kernel_mapping,
  31. replace_kernel_forward_from_hub,
  32. )
  33. from kernels import (
  34. get_kernel as get_kernel_hub,
  35. )
  36. from kernels import (
  37. use_kernel_forward_from_hub as _kernels_use_kernel_forward_from_hub,
  38. )
  39. # Try to import FuncRepository, fallback if not available
  40. try:
  41. from kernels import FuncRepository
  42. except ImportError:
  43. FuncRepository = None
  44. # Try to import use_kernel_func_from_hub, fallback if not available
  45. try:
  46. from kernels import use_kernel_func_from_hub as _kernels_use_kernel_func_from_hub
  47. _has_use_kernel_func_from_hub = True
  48. except ImportError:
  49. _has_use_kernel_func_from_hub = False
  50. _TRANSFORMERS_USE_HUB_KERNELS = os.environ.get("USE_HUB_KERNELS", "YES").upper()
  51. _kernels_available = True
  52. _kernels_enabled = _TRANSFORMERS_USE_HUB_KERNELS in ENV_VARS_TRUE_VALUES
  53. def use_kernel_forward_from_hub(layer_name: str):
  54. if _kernels_enabled:
  55. return _kernels_use_kernel_forward_from_hub(layer_name)
  56. else:
  57. logger.warning_once(
  58. f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
  59. )
  60. return lambda cls: cls
  61. def use_kernel_func_from_hub(func_name: str):
  62. if _kernels_enabled and _has_use_kernel_func_from_hub:
  63. return _kernels_use_kernel_func_from_hub(func_name)
  64. else:
  65. if not _has_use_kernel_func_from_hub:
  66. logger.warning_once(
  67. "use_kernel_func_from_hub is not available in the installed kernels version. "
  68. "Please upgrade kernels to use this feature."
  69. )
  70. else:
  71. logger.warning_once(
  72. f"kernels hub usage is disabled through the environment USE_HUB_KERNELS={_TRANSFORMERS_USE_HUB_KERNELS}"
  73. )
  74. return lambda func: func
  75. _KERNEL_MAPPING: dict[str, dict[Device | str, LayerRepository | dict[Mode, LayerRepository]]] = {
  76. "MultiScaleDeformableAttention": {
  77. "cuda": LayerRepository(
  78. repo_id="kernels-community/deformable-detr",
  79. layer_name="MultiScaleDeformableAttention",
  80. )
  81. },
  82. "Llama4TextMoe": {
  83. "cuda": LayerRepository(
  84. repo_id="kernels-community/moe",
  85. layer_name="Llama4TextMoe",
  86. )
  87. },
  88. "RMSNorm": {
  89. "cuda": {
  90. Mode.INFERENCE: LayerRepository(
  91. repo_id="kernels-community/liger_kernels",
  92. layer_name="LigerRMSNorm",
  93. # revision="pure-layer-test",
  94. ),
  95. },
  96. "rocm": {
  97. Mode.INFERENCE: LayerRepository(
  98. repo_id="kernels-community/liger_kernels",
  99. layer_name="LigerRMSNorm",
  100. )
  101. },
  102. "xpu": {
  103. Mode.INFERENCE: LayerRepository(
  104. repo_id="kernels-community/rmsnorm",
  105. layer_name="RMSNorm",
  106. )
  107. },
  108. "mps": {
  109. Mode.INFERENCE: LayerRepository(
  110. repo_id="kernels-community/mlx_rmsnorm",
  111. layer_name="RMSNorm",
  112. )
  113. },
  114. "npu": {
  115. Mode.INFERENCE: LayerRepository(
  116. repo_id="kernels-community/liger_kernels",
  117. layer_name="LigerRMSNorm",
  118. )
  119. },
  120. },
  121. "MLP": {
  122. "cuda": LayerRepository(
  123. repo_id="medmekk/triton-llama-mlp",
  124. layer_name="TritonLlamaMLP",
  125. )
  126. },
  127. "MegaBlocksMoeMLP": {
  128. "cuda": {
  129. Mode.TRAINING: LayerRepository(
  130. repo_id="kernels-community/megablocks",
  131. layer_name="MegaBlocksMoeMLP",
  132. ),
  133. Mode.INFERENCE: LayerRepository(
  134. repo_id="kernels-community/megablocks",
  135. layer_name="MegaBlocksMoeMLP",
  136. ),
  137. },
  138. "rocm": {
  139. Mode.INFERENCE: LayerRepository(
  140. repo_id="ahadnagy/megablocks",
  141. layer_name="MegaBlocksMoeMLP",
  142. )
  143. },
  144. "xpu": {
  145. Mode.INFERENCE: LayerRepository(
  146. repo_id="kernels-community/megablocks",
  147. layer_name="MegaBlocksMoeMLP",
  148. )
  149. },
  150. "cpu": {
  151. Mode.INFERENCE: LayerRepository(
  152. repo_id="kernels-community/megablocks",
  153. layer_name="CPUMegaBlocksMoeMLP",
  154. )
  155. },
  156. },
  157. "FastGELU": {
  158. "cuda": {
  159. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  160. repo_id="kernels-community/activation",
  161. layer_name="FastGELU",
  162. version=1,
  163. )
  164. }
  165. },
  166. "QuickGELU": {
  167. "cuda": {
  168. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  169. repo_id="kernels-community/activation",
  170. layer_name="QuickGELU",
  171. version=1,
  172. )
  173. }
  174. },
  175. "NewGELU": {
  176. "cuda": {
  177. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  178. repo_id="kernels-community/activation",
  179. layer_name="NewGELU",
  180. version=1,
  181. )
  182. }
  183. },
  184. "SiLU": {
  185. "cuda": {
  186. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  187. repo_id="kernels-community/activation", layer_name="Silu", version=1
  188. )
  189. }
  190. },
  191. "GeLU": {
  192. "cuda": {
  193. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  194. repo_id="kernels-community/activation", layer_name="Gelu", version=1
  195. )
  196. }
  197. },
  198. "GeluTanh": {
  199. "cuda": {
  200. Mode.INFERENCE | Mode.TORCH_COMPILE: LayerRepository(
  201. repo_id="kernels-community/activation", layer_name="GeluTanh", version=1
  202. )
  203. }
  204. },
  205. }
  206. # Add function kernel mappings if FuncRepository is available
  207. if FuncRepository is not None:
  208. _KERNEL_MAPPING["rotary_pos_emb"] = {
  209. "xpu": {
  210. Mode.INFERENCE: FuncRepository(
  211. repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
  212. )
  213. },
  214. "cuda": {
  215. Mode.INFERENCE: FuncRepository(
  216. repo_id="kernels-community/rotary", func_name="apply_rotary_transformers"
  217. )
  218. },
  219. }
  220. def has_key(d, key):
  221. return key in d or any(isinstance(v, dict) and has_key(v, key) for v in d.values())
  222. def register_kernel_mapping_transformers(mapping=None):
  223. if mapping is None:
  224. mapping = _KERNEL_MAPPING
  225. if has_key(mapping, "xpu") and not is_kernels_available(MIN_VERSION="0.10.2"):
  226. raise ImportError(
  227. "kernels uses an incompatible version. Please install the latest version with `pip install -U kernels`."
  228. )
  229. register_kernel_mapping(mapping)
  230. except ImportError:
  231. _kernels_available = False
  232. _kernels_enabled = False
  233. # Stub to make decorators int transformers work when `kernels`
  234. # is not installed.
  235. def use_kernel_forward_from_hub(*args, **kwargs):
  236. def decorator(cls):
  237. return cls
  238. return decorator
  239. def use_kernel_func_from_hub(*args, **kwargs):
  240. def decorator(func):
  241. return func
  242. return decorator
  243. class LayerRepository:
  244. def __init__(self, *args, **kwargs):
  245. raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
  246. def replace_kernel_forward_from_hub(*args, **kwargs):
  247. raise RuntimeError(
  248. "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
  249. )
  250. def register_kernel_mapping(*args, **kwargs):
  251. raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
  252. def register_kernel_mapping_transformers(*args, **kwargs):
  253. raise RuntimeError(
  254. "register_kernel_mapping_transformers requires `kernels` to be installed. Run `pip install kernels`."
  255. )
  256. _HUB_KERNEL_MAPPING: dict[str, dict[str, str]] = {
  257. "causal-conv1d": {"repo_id": "kernels-community/causal-conv1d", "version": 1},
  258. "mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
  259. "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1},
  260. "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1},
  261. "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1},
  262. }
  263. _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {}
  264. def is_kernel(attn_implementation: str | None) -> bool:
  265. """Check whether `attn_implementation` matches a kernel pattern from the hub."""
  266. return (
  267. attn_implementation is not None
  268. and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
  269. )
  270. def load_and_register_attn_kernel(
  271. attn_implementation: str, attention_wrapper: Callable | None = None, allow_all_kernels: bool = False
  272. ) -> ModuleType | None:
  273. """
  274. Load and register the kernel associated to `attn_implementation`.
  275. Args:
  276. attn_implementation: A string, usually a kernel repo like "kernels-community/flash-mla".
  277. attn_wrapper: a callable for the wrapper around the attention implementation. In `transformers` we
  278. have a wrapper around the `flash_attn_var_len` call, and the same goes for `sdpa` and `eager`.
  279. They just prepare the arguments properly. This is mostly used for continious batching, where we
  280. want the `paged` wrapper, which calls the paged cache.
  281. allow_all_kernels (`bool`, optional):
  282. Whether to load kernels from unverified hub repos, if it is a custom kernel outside of the `kernels-community`
  283. hub repository.
  284. """
  285. from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
  286. from ..modeling_utils import ALL_ATTENTION_FUNCTIONS
  287. actual_attn_name = attn_implementation.split("|")[1] if "|" in attn_implementation else attn_implementation
  288. if not is_kernel(actual_attn_name):
  289. return None
  290. if not _kernels_available:
  291. raise ImportError(
  292. "`kernels` is either not installed or uses an incompatible version. "
  293. "Please install the latest version with `pip install -U kernels`."
  294. )
  295. # Extract repo_id and kernel_name from the string
  296. if ":" in actual_attn_name:
  297. repo_id, kernel_name = actual_attn_name.split(":")
  298. kernel_name = kernel_name.strip()
  299. else:
  300. repo_id = actual_attn_name
  301. kernel_name = None
  302. repo_id = repo_id.strip()
  303. # extract the rev after the @ if it exists
  304. repo_id, _, rev = repo_id.partition("@")
  305. repo_id = repo_id.strip()
  306. rev = rev.strip() if rev else None
  307. # Load the kernel from hub
  308. try:
  309. kernel = get_kernel(repo_id, revision=rev, allow_all_kernels=allow_all_kernels)
  310. except Exception as e:
  311. raise ValueError(f"An error occurred while trying to load from '{repo_id}': {e}.")
  312. # correctly wrap the kernel
  313. if hasattr(kernel, "flash_attn_varlen_func"):
  314. if attention_wrapper is None:
  315. attention_wrapper = flash_attention_forward
  316. kernel_function = attention_wrapper
  317. elif kernel_name is not None:
  318. kernel_function = getattr(kernel, kernel_name)
  319. # Register the kernel as a valid attention
  320. ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
  321. ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])
  322. return kernel
  323. def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _KERNEL_MODULE_MAPPING):
  324. if kernel_name in mapping and isinstance(mapping[kernel_name], ModuleType):
  325. return mapping[kernel_name]
  326. if kernel_name not in _HUB_KERNEL_MAPPING:
  327. logger.warning_once(f"Kernel {kernel_name} not found in _HUB_KERNEL_MAPPING")
  328. mapping[kernel_name] = None
  329. return None
  330. if _kernels_available:
  331. try:
  332. repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"]
  333. revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None)
  334. version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None)
  335. kernel = get_kernel(repo_id, revision=revision, version=version)
  336. mapping[kernel_name] = kernel
  337. except FileNotFoundError:
  338. mapping[kernel_name] = None
  339. except AssertionError:
  340. # Happens when torch is built without an accelerator backend; fall back to slow path.
  341. mapping[kernel_name] = None
  342. else:
  343. # Try to import is_{kernel_name}_available from ..utils
  344. import importlib
  345. new_kernel_name = kernel_name.replace("-", "_")
  346. func_name = f"is_{new_kernel_name}_available"
  347. try:
  348. utils_mod = importlib.import_module("..utils.import_utils", __package__)
  349. is_kernel_available = getattr(utils_mod, func_name, None)
  350. except Exception:
  351. is_kernel_available = None
  352. if callable(is_kernel_available) and is_kernel_available():
  353. # Try to import the module "{kernel_name}" from parent package level
  354. try:
  355. module = importlib.import_module(f"{new_kernel_name}")
  356. mapping[kernel_name] = module
  357. return module
  358. except Exception:
  359. mapping[kernel_name] = None
  360. else:
  361. mapping[kernel_name] = None
  362. return mapping[kernel_name]
  363. def get_kernel(
  364. kernel_name: str,
  365. revision: str | None = None,
  366. version: int | str | None = None,
  367. allow_all_kernels: bool = False,
  368. ) -> ModuleType:
  369. from .. import __version__
  370. if not _kernels_available:
  371. raise ImportError(
  372. "`kernels` is either not installed or uses an incompatible version. Please install the latest version "
  373. "with `pip install -U kernels`."
  374. )
  375. repo_parent = kernel_name.split("/")[0]
  376. # all `kernels-community` repos are trusted by default!
  377. if repo_parent != "kernels-community" and not allow_all_kernels:
  378. raise ValueError(
  379. "You need to specify `allow_all_kernels=True` to use kernels outside of the `kernels-community` repository"
  380. )
  381. user_agent = {"framework": "transformers", "version": __version__, "repo_id": kernel_name}
  382. kernels_version = importlib.metadata.version("kernels")
  383. if pkg_version.parse(kernels_version) >= pkg_version.parse("0.10.4"):
  384. return get_kernel_hub(kernel_name, revision=revision, version=version, user_agent=user_agent)
  385. else:
  386. return get_kernel_hub(kernel_name, revision=revision, version=version)
  387. def use_kernelized_func(module_names: list[Callable] | Callable):
  388. """
  389. This decorator attaches the target function as an attribute of the module.
  390. The function must already be decorated with @use_kernel_func_from_hub
  391. this decorator then wraps it as an nn.Module internally.
  392. When kernelize is later applied to the full model, the function can be accessed as a regular module attribute and kernelized just like any other layer.
  393. The kernelization is performed in place, modifying the module directly.
  394. """
  395. if isinstance(module_names, Callable):
  396. module_names = [module_names]
  397. def decorator(cls):
  398. orig_init = cls.__init__
  399. def new_init(self, *args, **kwargs):
  400. orig_init(self, *args, **kwargs)
  401. # Skip attaching the kernelized submodule under DeepSpeed ZeRO-3: the coordinator traces
  402. # the module graph at init time, and a child `nn.Module` that is not actually invoked
  403. # during forward (e.g. when the model keeps calling the plain Python `apply_rotary_pos_emb`)
  404. # breaks the parameter fetch trace and raises `IndexError: pop from an empty deque`.
  405. # See https://github.com/huggingface/transformers/issues/45137
  406. from .deepspeed import is_deepspeed_zero3_enabled
  407. if is_deepspeed_zero3_enabled():
  408. return
  409. for fn in module_names:
  410. # we hardcode the name of the function to "rotary_fn" for now
  411. setattr(self, "rotary_fn", fn)
  412. cls.__init__ = new_init
  413. return cls
  414. return decorator
  415. # Whether to allow hub kernels coming from untrusted repos, i.e. repos outside `kernels-community`
  416. ALLOW_ALL_KERNELS = False
  417. @contextmanager
  418. def allow_all_hub_kernels():
  419. """
  420. Context manager used to adjust the value of the global `ALLOW_HUB_KERNELS`. This is needed, as this argument
  421. cannot be forwarded directly to the `__init__` of the models, where we set the attention implementation.
  422. """
  423. global ALLOW_ALL_KERNELS
  424. try:
  425. ALLOW_ALL_KERNELS = True
  426. yield
  427. finally:
  428. # Set back the original
  429. ALLOW_ALL_KERNELS = False
# Names exported by `from ... import *`. Other helpers defined in this module (e.g. `is_kernel`,
# `load_and_register_attn_kernel`, `allow_all_hub_kernels`) are not listed here and must be imported by name.
__all__ = [
    "LayerRepository",
    "use_kernel_forward_from_hub",
    "use_kernel_func_from_hub",
    "register_kernel_mapping",
    "register_kernel_mapping_transformers",
    "replace_kernel_forward_from_hub",
    "lazy_load_kernel",
    "get_kernel",
    "use_kernelized_func",
]  # type: ignore
  438. "get_kernel",
  439. "use_kernelized_func",
  440. ] # type: ignore