kernel_config.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..utils import PushToHubMixin
  15. def infer_device(model):
  16. """
  17. Infers the device type from the model parameters.
  18. Args:
  19. model: The model instance.
  20. Returns:
  21. The device type.
  22. """
  23. EXAMPLE_MAPPING = """
  24. {
  25. "RMSNorm": {
  26. "cuda":
  27. "kernels-community/layer_norm:LlamaRMSNorm",
  28. ...
  29. },
  30. ...
  31. }
  32. """
  33. try:
  34. param = next(model.parameters())
  35. except StopIteration:
  36. raise ValueError(
  37. f"Cannot determine model device, please provide a device to the mapping. Example: {EXAMPLE_MAPPING}"
  38. )
  39. dev_type = param.device.type
  40. if dev_type == "cuda":
  41. # Refine based on actual platform
  42. from ..utils import is_torch_available
  43. if is_torch_available():
  44. import torch
  45. if getattr(torch, "version").hip is not None:
  46. return "rocm"
  47. return dev_type
  48. def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping):
  49. from kernels import LayerRepository
  50. if device not in ["cuda", "rocm", "xpu", "npu", "neuron"]:
  51. raise ValueError(f"Only cuda, rocm, xpu, npu and neuron devices supported, got: {device}")
  52. repo_layer_name = repo_name.split(":")[1]
  53. repo_id = repo_name.split(":")[0]
  54. compatible_mapping[layer_name] = {
  55. device: {
  56. mode: LayerRepository(
  57. repo_id=repo_id,
  58. layer_name=repo_layer_name,
  59. )
  60. }
  61. }
  62. def add_to_mapping_local(layer_name, device, repo_name, mode, compatible_mapping):
  63. from pathlib import Path
  64. from kernels import LocalLayerRepository
  65. if device not in ["cuda", "rocm", "xpu", "npu", "neuron"]:
  66. raise ValueError(f"Only cuda, rocm, xpu, npu and neuron devices supported, got: {device}")
  67. repo_layer_name = repo_name.split(":")[1]
  68. repo_path = repo_name.split(":")[0]
  69. repo_package_name = repo_path.split("/")[-1]
  70. compatible_mapping[layer_name] = {
  71. device: {
  72. mode: LocalLayerRepository(
  73. repo_path=Path(repo_path),
  74. package_name=repo_package_name,
  75. layer_name=repo_layer_name,
  76. )
  77. }
  78. }
  79. class KernelConfig(PushToHubMixin):
  80. """
  81. Kernel configuration class. This class is used to configure the kernel mapping for a model.
  82. """
  83. def __init__(self, kernel_mapping=None, use_local_kernel=False):
  84. self.kernel_mapping = kernel_mapping if kernel_mapping is not None else {}
  85. self.registered_layer_names = {}
  86. self.use_local_kernel = use_local_kernel
  87. def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revision=None):
  88. from kernels import LayerRepository
  89. self.kernel_mapping[registered_name] = {
  90. device: {
  91. mode: LayerRepository(
  92. repo_id=repo_id,
  93. layer_name=layer_name,
  94. revision=revision,
  95. )
  96. }
  97. }
  98. def store_registered_layer_names(self, model):
  99. for name, module in model.named_modules():
  100. if hasattr(module, "kernel_layer_name"):
  101. self.registered_layer_names[name] = module.kernel_layer_name
  102. def sanitize_kernel_mapping(self, model):
  103. """
  104. Validates the kernel_mapping to ensure that:
  105. 1. Each layer_name in the mapping is registered in the model (i.e., the model contains a module with a matching kernel_layer_name).
  106. 2. Each kernel value is either a string of the form 'org/repo:layer_name' or a dict mapping device types ("cuda", "rocm", "xpu", "npu") to such strings.
  107. 3. Each device key in a dict is one of "cuda", "rocm", "xpu", or "npu".
  108. 4. Each repo_name is a valid repository and layer name in the format 'org/repo:layer_name' (i.e., a string containing both a slash and a colon).
  109. 5. If a local path is detected, it should be in the format '/abs/path:layer_name'. The absolute path must include the `package_name`, like "/home/user/layer_norm".
  110. Args:
  111. model: The model instance whose modules are checked for registered kernel_layer_name attributes.
  112. Raises:
  113. ValueError: If a layer_name is not registered in the model, if a device is not supported,
  114. or if a repo_name is not a valid 'org/repo:layer_name' string.
  115. """
  116. MAPPING_FORMAT = """
  117. For single device form remote
  118. {
  119. "RMSNorm":
  120. "kernels-community/layer_norm:LlamaRMSNorm",
  121. ...
  122. },
  123. For multiple devices form remote
  124. {
  125. "RMSNorm": {
  126. "cuda":
  127. "kernels-community/layer_norm:LlamaRMSNorm",
  128. "rocm":
  129. "kernels-community/layer_norm:LlamaRMSNorm",
  130. ...
  131. },
  132. ...
  133. }
  134. For single device form local
  135. {
  136. "RMSNorm":
  137. "/abs/path:LlamaRMSNorm",
  138. ...
  139. },
  140. For multiple devices form local
  141. {
  142. "RMSNorm": {
  143. "cuda":
  144. "/abs/path:LlamaRMSNorm",
  145. "rocm":
  146. "/abs/path:LlamaRMSNorm",
  147. ...
  148. },
  149. ...
  150. }
  151. """
  152. self.store_registered_layer_names(model)
  153. # Validate that the kernel mapping is a dict
  154. if not isinstance(self.kernel_mapping, dict):
  155. raise ValueError(
  156. f"Kernel mapping must be a dict of the following format: {MAPPING_FORMAT}, got: {type(self.kernel_mapping)}"
  157. )
  158. for layer_name, kernel in self.kernel_mapping.items():
  159. if layer_name not in self.registered_layer_names.values():
  160. raise ValueError(
  161. f"Layer {layer_name} is not registered in the model, please register it first using use_kernel_forward_from_hub"
  162. )
  163. if isinstance(kernel, str):
  164. if "/" not in kernel or ":" not in kernel:
  165. raise ValueError(
  166. f"Kernel mapping for '{layer_name}' must be a valid repo name with a layer name (e.g., 'org/repo:layer_name' or '/abs/path:layer_name'), got: {kernel}"
  167. )
  168. elif isinstance(kernel, dict):
  169. for device, repo_name in kernel.items():
  170. if device not in ["cuda", "rocm", "xpu", "npu", "neuron"]:
  171. raise ValueError(f"Only cuda, rocm, xpu, npu and neuron devices supported, got: {device}")
  172. if not isinstance(repo_name, str) or "/" not in repo_name or ":" not in repo_name:
  173. raise ValueError(
  174. f"Kernel mapping for '{layer_name}' must be a valid repo name with a layer name (e.g., 'org/repo:layer_name' or '/abs/path:layer_name'), got: {repo_name}"
  175. )
  176. else:
  177. raise ValueError(f"Kernel mapping must follow the format: {MAPPING_FORMAT}, got: {kernel}")
  178. def create_compatible_mapping(self, model, compile=False):
  179. """
  180. Transforms a simple kernel_mapping of the form:
  181. {
  182. "RMSNorm":
  183. "kernels-community/layer_norm:LlamaRMSNorm",
  184. ...
  185. },
  186. or for local path:
  187. {
  188. "RMSNorm":
  189. "/home/user/liger_kernels:LigerRMSNorm",
  190. ...
  191. },
  192. into a nested mapping:
  193. {
  194. "RMSNorm": {
  195. "cuda": {
  196. Mode.INFERENCE: LayerRepository(
  197. repo_id="kernels-community/layer_norm",
  198. layer_name="LlamaRMSNorm",
  199. )
  200. }
  201. }
  202. }
  203. or for local path:
  204. {
  205. "RMSNorm": {
  206. "cuda": {
  207. Mode.INFERENCE: LocalLayerRepository(
  208. repo_path=Path("/home/user/liger_kernels"),
  209. package_name="liger_kernels",
  210. layer_name="LigerRMSNorm",
  211. )
  212. }
  213. }
  214. }
  215. that's compatible with the kernels library.
  216. The device is inferred from the model's parameters if not provided.
  217. The Mode is inferred from the model's training state.
  218. """
  219. from kernels import Mode
  220. compatible_mapping = {}
  221. current_device = infer_device(model)
  222. for layer_name, kernel in self.kernel_mapping.items():
  223. # Infer Mode: use Mode.TRAINING if model is training, else use Mode.INFERENCE
  224. mode = Mode.TRAINING if model.training else Mode.INFERENCE
  225. if compile:
  226. mode = mode | Mode.TORCH_COMPILE
  227. if isinstance(kernel, str):
  228. repo_name = kernel
  229. if not self.use_local_kernel:
  230. add_to_mapping(layer_name, current_device, repo_name, mode, compatible_mapping)
  231. else:
  232. add_to_mapping_local(layer_name, current_device, repo_name, mode, compatible_mapping)
  233. elif isinstance(kernel, dict):
  234. for device, repo_name in kernel.items():
  235. if device != current_device:
  236. continue
  237. if not self.use_local_kernel:
  238. add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)
  239. else:
  240. add_to_mapping_local(layer_name, device, repo_name, mode, compatible_mapping)
  241. self.kernel_mapping = compatible_mapping