quantizer_mxfp4.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from .base import HfQuantizer
  16. if TYPE_CHECKING:
  17. from ..modeling_utils import PreTrainedModel
  18. from ..utils.quantization_config import Mxfp4Config
  19. from ..utils import (
  20. is_accelerate_available,
  21. is_kernels_available,
  22. is_torch_available,
  23. is_triton_available,
  24. logging,
  25. )
  26. from .quantizers_utils import get_module_from_name
  27. if is_torch_available():
  28. import torch
  29. from ..core_model_loading import WeightConverter
  30. logger = logging.get_logger(__name__)
  31. triton_kernels_hub = None
  32. class Mxfp4HfQuantizer(HfQuantizer):
  33. """
  34. FP4 quantization using fbgemm kernels
  35. """
  36. requires_calibration = False
  37. quantization_config: "Mxfp4Config"
  38. def __init__(self, quantization_config, **kwargs):
  39. super().__init__(quantization_config, **kwargs)
  40. self.triton_kernels_hub = None
  41. def _lazy_import_kernels(self):
  42. """Lazy import and initialize kernels only when needed"""
  43. if self.triton_kernels_hub is None:
  44. try:
  45. from ..integrations.hub_kernels import get_kernel
  46. self.triton_kernels_hub = get_kernel("kernels-community/gpt-oss-triton-kernels")
  47. except ImportError:
  48. raise ImportError("kernels package is required for MXFP4 quantization")
  49. return self.triton_kernels_hub
  50. def validate_environment(self, *args, **kwargs):
  51. if not is_torch_available():
  52. raise ImportError(
  53. "Using mxfp4 quantization requires torch"
  54. "Please install the latest version of torch ( pip install --upgrade torch )"
  55. )
  56. if self.quantization_config.dequantize:
  57. return
  58. if not is_accelerate_available():
  59. raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
  60. device = torch.accelerator.current_accelerator() or torch.device("cpu")
  61. if device.type not in ["cuda", "xpu", "cpu"]:
  62. if self.pre_quantized:
  63. logger.warning_once(
  64. f"Using MXFP4 quantized models requires model on cuda/xpu/cpu, but found {device}, we will default to dequantizing the model to bf16. To use mxfp4, please disable the current accelerator."
  65. )
  66. self.quantization_config.dequantize = True
  67. return
  68. else:
  69. raise RuntimeError(
  70. f"Quantizing a model using MXFP4 requires model on cuda/xpu/cpu, but found {device}. To use mxfp4, please disable the current accelerator."
  71. )
  72. if torch.xpu.is_available():
  73. is_device_supported_mxfp4 = True
  74. triton_available = is_triton_available("3.5.0")
  75. kernels_installed = is_kernels_available()
  76. elif torch.cuda.is_available():
  77. compute_capability = torch.cuda.get_device_capability()
  78. is_device_supported_mxfp4 = compute_capability >= (7, 5)
  79. triton_available = is_triton_available("3.4.0")
  80. kernels_installed = is_kernels_available()
  81. elif device.type == "cpu":
  82. is_device_supported_mxfp4 = True
  83. triton_available = is_triton_available("3.5.0")
  84. kernels_installed = is_kernels_available()
  85. else:
  86. is_device_supported_mxfp4 = False
  87. triton_available = False
  88. kernels_installed = False
  89. if self.pre_quantized:
  90. if not is_device_supported_mxfp4:
  91. logger.warning_once(
  92. "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 "
  93. "(e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series). "
  94. "We will default to dequantizing the model to bf16."
  95. )
  96. self.quantization_config.dequantize = True
  97. return
  98. if not triton_available:
  99. logger.warning_once(
  100. "MXFP4 quantization requires Triton: CUDA requires Triton >= 3.4.0, "
  101. "XPU/CPU requires Triton >= 3.5.0. Please install triton: `pip install triton`. "
  102. "We will default to dequantizing the model to bf16."
  103. )
  104. self.quantization_config.dequantize = True
  105. return
  106. if not kernels_installed:
  107. logger.warning_once(
  108. "MXFP4 quantization requires the `kernels` package: "
  109. "`pip install kernels>=0.12.0`. "
  110. "We will default to dequantizing the model to bf16."
  111. )
  112. self.quantization_config.dequantize = True
  113. return
  114. elif not is_device_supported_mxfp4:
  115. raise ValueError(
  116. "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 "
  117. "(e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) or CPU"
  118. )
  119. elif not triton_available:
  120. raise ValueError(
  121. "MXFP4 quantization requires Triton: CUDA requires Triton >= 3.4.0, "
  122. "XPU/CPU requires Triton >= 3.5.0. Please install triton: `pip install triton`"
  123. )
  124. elif not kernels_installed:
  125. raise ValueError("MXFP4 quantization requires the `kernels` package: `pip install kernels>=0.12.0`")
  126. if not self.pre_quantized:
  127. self._lazy_import_kernels()
  128. device_map = kwargs.get("device_map")
  129. if device_map is not None and isinstance(device_map, dict):
  130. if not self.pre_quantized and "disk" in device_map.values():
  131. raise ValueError(
  132. "You are attempting to load an FP4 model with a device_map that contains a disk device."
  133. "This is not supported when the model is quantized on the fly. "
  134. "Please use a quantized checkpoint or remove the disk device from the device_map."
  135. )
  136. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  137. from ..integrations import Mxfp4GptOssExperts
  138. module, tensor_name = get_module_from_name(model, param_name)
  139. if isinstance(module, Mxfp4GptOssExperts):
  140. if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
  141. return False
  142. return True
  143. return False
  144. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  145. # clean cache due to triton ops
  146. if torch.cuda.is_available():
  147. torch.cuda.empty_cache()
  148. elif torch.xpu.is_available():
  149. torch.xpu.empty_cache()
  150. def _process_model_before_weight_loading(
  151. self,
  152. model: "PreTrainedModel",
  153. use_kernels: bool = False,
  154. **kwargs,
  155. ):
  156. from ..integrations import replace_with_mxfp4_linear
  157. # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
  158. # only CPU kernels can work with pre-quantized models
  159. device = torch.accelerator.current_accelerator() or torch.device("cpu")
  160. if use_kernels and device.type not in ["cpu"]:
  161. logger.warning_once(
  162. "You are using full precision kernels, we will dequantize the model to bf16. "
  163. "To use the quantized model with quantization kernels, please set use_kernels=False"
  164. )
  165. self.quantization_config.dequantize = True
  166. if not use_kernels and device.type in ["cpu"]:
  167. logger.warning_once(
  168. "MXFP4 inference on CPU requires use_kernels=True, but use_kernels is disabled. "
  169. "We will dequantize the model to bf16. To run MXFP4 natively on CPU, please set use_kernels=True."
  170. )
  171. self.quantization_config.dequantize = True
  172. self.modules_to_not_convert = self.get_modules_to_not_convert(
  173. model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
  174. )
  175. model = replace_with_mxfp4_linear(
  176. model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
  177. )
  178. def update_tp_plan(self, config):
  179. if "GptOssConfig" in config.__class__.__name__:
  180. if getattr(config, "base_model_tp_plan", None) is not None:
  181. config.base_model_tp_plan.update(
  182. {
  183. "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
  184. "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
  185. "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
  186. "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
  187. }
  188. )
  189. return config
  190. def update_ep_plan(self, config):
  191. if "GptOssConfig" in config.__class__.__name__:
  192. if getattr(config, "base_model_ep_plan", None) is not None:
  193. config.base_model_ep_plan.update(
  194. {
  195. "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
  196. "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
  197. "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
  198. "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
  199. }
  200. )
  201. return config
  202. def get_state_dict_and_metadata(self, model):
  203. from ..integrations import Mxfp4GptOssExperts
  204. state_dict = model.state_dict()
  205. num_local_experts = getattr(model.config, "num_local_experts", 32)
  206. hidden_size = getattr(model.config, "hidden_size", 2880)
  207. for name, module in model.named_modules():
  208. if not (
  209. isinstance(module, Mxfp4GptOssExperts)
  210. and hasattr(module, "gate_up_proj")
  211. and hasattr(module, "down_proj")
  212. ):
  213. continue
  214. for proj in ("gate_up_proj", "down_proj"):
  215. triton_tensor = getattr(module, proj)
  216. precision_config = getattr(module, f"{proj}_precision_config")
  217. blocks = triton_tensor.storage.layout.unswizzle_data(triton_tensor.storage.data).transpose(-1, -2)
  218. if proj == "gate_up_proj":
  219. blocks = blocks.reshape(num_local_experts, -1, 90, 16)
  220. else:
  221. blocks = blocks.reshape(num_local_experts, hidden_size, 90, -1)
  222. scales = precision_config.weight_scale.storage.layout.unswizzle_data(
  223. precision_config.weight_scale.storage.data
  224. ).transpose(-1, -2)
  225. state_dict[f"{name}.{proj}_blocks"] = blocks
  226. state_dict[f"{name}.{proj}_scales"] = scales
  227. metadata = {}
  228. return state_dict, metadata
  229. def is_serializable(self):
  230. return True
  231. @property
  232. def is_trainable(self) -> bool:
  233. logger.warning_once(
  234. "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()"
  235. )
  236. return False
  237. def get_quantize_ops(self):
  238. from ..integrations.mxfp4 import Mxfp4Quantize
  239. return Mxfp4Quantize(self)
  240. def get_weight_conversions(self):
  241. from ..integrations.mxfp4 import Mxfp4Dequantize, Mxfp4Deserialize
  242. if self.pre_quantized and self.quantization_config.dequantize:
  243. return [
  244. WeightConverter(
  245. source_patterns=["down_proj_blocks", "down_proj_scales"],
  246. target_patterns=r"down_proj$",
  247. operations=[Mxfp4Dequantize(self)],
  248. ),
  249. WeightConverter(
  250. source_patterns=["gate_up_proj_blocks", "gate_up_proj_scales"],
  251. target_patterns=["gate_up_proj$"],
  252. operations=[Mxfp4Dequantize(self)],
  253. ),
  254. ]
  255. return [
  256. WeightConverter(
  257. source_patterns=["gate_up_proj_blocks", "gate_up_proj_scales"],
  258. target_patterns=r"gate_up_proj$",
  259. operations=[Mxfp4Deserialize(self)],
  260. ),
  261. WeightConverter(
  262. source_patterns=["down_proj_blocks", "down_proj_scales"],
  263. target_patterns=r"down_proj$",
  264. operations=[Mxfp4Deserialize(self)],
  265. ),
  266. ]