| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import TYPE_CHECKING
- from .base import HfQuantizer
- if TYPE_CHECKING:
- from ..modeling_utils import PreTrainedModel
- from ..utils.quantization_config import Mxfp4Config
- from ..utils import (
- is_accelerate_available,
- is_kernels_available,
- is_torch_available,
- is_triton_available,
- logging,
- )
- from .quantizers_utils import get_module_from_name
- if is_torch_available():
- import torch
- from ..core_model_loading import WeightConverter
# Module-level logger obtained from the project's `logging` helper; used for the
# `warning_once` messages emitted by the quantizer below.
logger = logging.get_logger(__name__)
# Module-level placeholder, initialised to None. Nothing in this file assigns it again;
# `Mxfp4HfQuantizer` keeps its own per-instance handle in `self.triton_kernels_hub`.
triton_kernels_hub = None
class Mxfp4HfQuantizer(HfQuantizer):
    """
    MXFP4 quantizer backed by Triton kernels fetched from the kernels hub.

    The quantizer targets `Mxfp4GptOssExperts` modules. When the environment cannot run the
    MXFP4 kernels and the checkpoint is pre-quantized, it falls back to dequantizing the model
    by setting `quantization_config.dequantize = True`; when quantizing on the fly it raises
    instead.
    """

    requires_calibration = False
    quantization_config: "Mxfp4Config"

    def __init__(self, quantization_config, **kwargs):
        """Store the configuration and start with no kernel hub handle loaded."""
        super().__init__(quantization_config, **kwargs)
        # Filled lazily by `_lazy_import_kernels`.
        self.triton_kernels_hub = None

    def _lazy_import_kernels(self):
        """
        Lazily import and initialize the Triton kernels, only when needed.

        Returns:
            The object returned by `get_kernel("kernels-community/gpt-oss-triton-kernels")`,
            cached on `self.triton_kernels_hub` after the first successful call.

        Raises:
            ImportError: if importing `get_kernel` or fetching the kernel raises `ImportError`.
        """
        if self.triton_kernels_hub is None:
            try:
                from ..integrations.hub_kernels import get_kernel

                self.triton_kernels_hub = get_kernel("kernels-community/gpt-oss-triton-kernels")
            except ImportError as err:
                # Chain the original error so the real import failure stays visible.
                raise ImportError("kernels package is required for MXFP4 quantization") from err
        return self.triton_kernels_hub

    def validate_environment(self, *args, **kwargs):
        """
        Check that the current environment can run MXFP4, or fall back to dequantization.

        The method returns early without further checks when `quantization_config.dequantize`
        is already set. For pre-quantized checkpoints, every unmet requirement (unsupported
        accelerator, unsupported device, missing Triton, missing `kernels`) logs a warning and
        switches `quantization_config.dequantize` to True. When quantizing on the fly the same
        conditions raise.

        Args:
            *args: unused.
            **kwargs: only `device_map` is read.

        Raises:
            ImportError: if torch or accelerate is not available.
            RuntimeError: if quantizing on the fly on an accelerator other than cuda/xpu/cpu.
            ValueError: if quantizing on the fly and the device, Triton or `kernels` requirement
                is not met, or if `device_map` is a dict that contains a "disk" entry.
        """
        if not is_torch_available():
            raise ImportError(
                "Using mxfp4 quantization requires torch. "
                "Please install the latest version of torch ( pip install --upgrade torch )"
            )

        # Nothing else to validate when the model is going to be dequantized anyway.
        if self.quantization_config.dequantize:
            return

        if not is_accelerate_available():
            raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")

        device = torch.accelerator.current_accelerator() or torch.device("cpu")
        if device.type not in ["cuda", "xpu", "cpu"]:
            if self.pre_quantized:
                logger.warning_once(
                    f"Using MXFP4 quantized models requires model on cuda/xpu/cpu, but found {device}, we will default to dequantizing the model to bf16. To use mxfp4, please disable the current accelerator."
                )
                self.quantization_config.dequantize = True
                return
            else:
                raise RuntimeError(
                    f"Quantizing a model using MXFP4 requires model on cuda/xpu/cpu, but found {device}. To use mxfp4, please disable the current accelerator."
                )

        # Per-backend requirements: XPU is probed first, then CUDA, then CPU.
        if torch.xpu.is_available():
            is_device_supported_mxfp4 = True
            triton_available = is_triton_available("3.5.0")
            kernels_installed = is_kernels_available()
        elif torch.cuda.is_available():
            compute_capability = torch.cuda.get_device_capability()
            is_device_supported_mxfp4 = compute_capability >= (7, 5)
            triton_available = is_triton_available("3.4.0")
            kernels_installed = is_kernels_available()
        elif device.type == "cpu":
            is_device_supported_mxfp4 = True
            triton_available = is_triton_available("3.5.0")
            kernels_installed = is_kernels_available()
        else:
            is_device_supported_mxfp4 = False
            triton_available = False
            kernels_installed = False

        if self.pre_quantized:
            # Pre-quantized checkpoints degrade gracefully: warn and dequantize.
            if not is_device_supported_mxfp4:
                logger.warning_once(
                    "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 "
                    "(e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series). "
                    "We will default to dequantizing the model to bf16."
                )
                self.quantization_config.dequantize = True
                return

            if not triton_available:
                logger.warning_once(
                    "MXFP4 quantization requires Triton: CUDA requires Triton >= 3.4.0, "
                    "XPU/CPU requires Triton >= 3.5.0. Please install triton: `pip install triton`. "
                    "We will default to dequantizing the model to bf16."
                )
                self.quantization_config.dequantize = True
                return

            if not kernels_installed:
                logger.warning_once(
                    "MXFP4 quantization requires the `kernels` package: "
                    "`pip install kernels>=0.12.0`. "
                    "We will default to dequantizing the model to bf16."
                )
                self.quantization_config.dequantize = True
                return
        elif not is_device_supported_mxfp4:
            raise ValueError(
                "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 "
                "(e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) or CPU"
            )
        elif not triton_available:
            raise ValueError(
                "MXFP4 quantization requires Triton: CUDA requires Triton >= 3.4.0, "
                "XPU/CPU requires Triton >= 3.5.0. Please install triton: `pip install triton`"
            )
        elif not kernels_installed:
            raise ValueError("MXFP4 quantization requires the `kernels` package: `pip install kernels>=0.12.0`")

        # On-the-fly quantization loads the kernels up front.
        if not self.pre_quantized:
            self._lazy_import_kernels()

        device_map = kwargs.get("device_map")
        if device_map is not None and isinstance(device_map, dict):
            if not self.pre_quantized and "disk" in device_map.values():
                raise ValueError(
                    "You are attempting to load an FP4 model with a device_map that contains a disk device. "
                    "This is not supported when the model is quantized on the fly. "
                    "Please use a quantized checkpoint or remove the disk device from the device_map."
                )

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """
        Tell whether the parameter `param_name` of `model` has to be quantized.

        Returns:
            True when the parameter belongs to an `Mxfp4GptOssExperts` module and is not one of
            the `down_proj_bias` / `gate_up_proj_bias` tensors; False otherwise.
        """
        from ..integrations import Mxfp4GptOssExperts

        module, tensor_name = get_module_from_name(model, param_name)
        if isinstance(module, Mxfp4GptOssExperts):
            if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
                return False
            return True
        return False

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """Empty the CUDA cache if CUDA is available, otherwise the XPU cache if XPU is available."""
        # clean cache due to triton ops
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.xpu.is_available():
            torch.xpu.empty_cache()

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        use_kernels: bool = False,
        **kwargs,
    ):
        """
        Prepare `model` for weight loading by swapping in the MXFP4 modules.

        Depending on `use_kernels` and the current device, this may first force
        `quantization_config.dequantize = True`. It then resolves `self.modules_to_not_convert`
        and calls `replace_with_mxfp4_linear` on the model.

        Args:
            model: the model to prepare.
            use_kernels: whether kernels are enabled for this model.
            **kwargs: unused.
        """
        from ..integrations import replace_with_mxfp4_linear

        # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
        # only CPU kernels can work with pre-quantized models
        device = torch.accelerator.current_accelerator() or torch.device("cpu")
        if use_kernels and device.type not in ["cpu"]:
            logger.warning_once(
                "You are using full precision kernels, we will dequantize the model to bf16. "
                "To use the quantized model with quantization kernels, please set use_kernels=False"
            )
            self.quantization_config.dequantize = True

        if not use_kernels and device.type in ["cpu"]:
            logger.warning_once(
                "MXFP4 inference on CPU requires use_kernels=True, but use_kernels is disabled. "
                "We will dequantize the model to bf16. To run MXFP4 natively on CPU, please set use_kernels=True."
            )
            self.quantization_config.dequantize = True

        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
        )

        # The return value is intentionally not kept: rebinding the local name would have no effect.
        replace_with_mxfp4_linear(
            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
        )

    @staticmethod
    def _add_expert_plan_entries(config, plan_attr: str):
        """
        Add the expert blocks/scales entries to the plan stored under `plan_attr` on `config`.

        Only configs whose class name contains "GptOssConfig" and whose `plan_attr` attribute
        is not None are modified; the plan is updated in place. Returns `config`.
        """
        if "GptOssConfig" in config.__class__.__name__:
            plan = getattr(config, plan_attr, None)
            if plan is not None:
                plan.update(
                    {
                        "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
                        "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
                        "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
                        "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
                    }
                )
        return config

    def update_tp_plan(self, config):
        """Register the expert blocks/scales entries in `config.base_model_tp_plan` and return `config`."""
        return self._add_expert_plan_entries(config, "base_model_tp_plan")

    def update_ep_plan(self, config):
        """Register the expert blocks/scales entries in `config.base_model_ep_plan` and return `config`."""
        return self._add_expert_plan_entries(config, "base_model_ep_plan")

    def get_state_dict_and_metadata(self, model):
        """
        Build a state dict in which expert weights are exposed as `*_blocks` / `*_scales` entries.

        For every `Mxfp4GptOssExperts` module that has both `gate_up_proj` and `down_proj`
        attributes, the stored tensors are unswizzled through their storage layout, transposed
        on the last two dimensions and written under `{name}.{proj}_blocks` and
        `{name}.{proj}_scales`.

        Returns:
            A `(state_dict, metadata)` tuple; `metadata` is always an empty dict.
        """
        from ..integrations import Mxfp4GptOssExperts

        state_dict = model.state_dict()
        # Fallback values used when the config does not define these attributes.
        num_local_experts = getattr(model.config, "num_local_experts", 32)
        hidden_size = getattr(model.config, "hidden_size", 2880)

        for name, module in model.named_modules():
            if not (
                isinstance(module, Mxfp4GptOssExperts)
                and hasattr(module, "gate_up_proj")
                and hasattr(module, "down_proj")
            ):
                continue

            for proj in ("gate_up_proj", "down_proj"):
                triton_tensor = getattr(module, proj)
                precision_config = getattr(module, f"{proj}_precision_config")

                blocks = triton_tensor.storage.layout.unswizzle_data(triton_tensor.storage.data).transpose(-1, -2)
                # NOTE(review): the 90 and 16 below are hard-coded; presumably they match the
                # gpt-oss checkpoint layout — confirm before reusing with other shapes.
                if proj == "gate_up_proj":
                    blocks = blocks.reshape(num_local_experts, -1, 90, 16)
                else:
                    blocks = blocks.reshape(num_local_experts, hidden_size, 90, -1)

                scales = precision_config.weight_scale.storage.layout.unswizzle_data(
                    precision_config.weight_scale.storage.data
                ).transpose(-1, -2)

                state_dict[f"{name}.{proj}_blocks"] = blocks
                state_dict[f"{name}.{proj}_scales"] = scales

        metadata = {}
        return state_dict, metadata

    def is_serializable(self):
        """Always report the quantized model as serializable."""
        return True

    @property
    def is_trainable(self) -> bool:
        """Always False; a warning is logged through `logger.warning_once` on access."""
        logger.warning_once(
            "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()"
        )
        return False

    def get_quantize_ops(self):
        """Return an `Mxfp4Quantize` operation bound to this quantizer."""
        from ..integrations.mxfp4 import Mxfp4Quantize

        return Mxfp4Quantize(self)

    def get_weight_conversions(self):
        """
        Return the `WeightConverter` list mapping `*_blocks` / `*_scales` to the projection weights.

        `Mxfp4Dequantize` is used when the checkpoint is pre-quantized and dequantization is
        requested; otherwise `Mxfp4Deserialize` is used.
        """
        from ..integrations.mxfp4 import Mxfp4Dequantize, Mxfp4Deserialize

        if self.pre_quantized and self.quantization_config.dequantize:
            return [
                WeightConverter(
                    source_patterns=["down_proj_blocks", "down_proj_scales"],
                    target_patterns=r"down_proj$",
                    operations=[Mxfp4Dequantize(self)],
                ),
                WeightConverter(
                    source_patterns=["gate_up_proj_blocks", "gate_up_proj_scales"],
                    target_patterns=["gate_up_proj$"],
                    operations=[Mxfp4Dequantize(self)],
                ),
            ]
        return [
            WeightConverter(
                source_patterns=["gate_up_proj_blocks", "gate_up_proj_scales"],
                target_patterns=r"gate_up_proj$",
                operations=[Mxfp4Deserialize(self)],
            ),
            WeightConverter(
                source_patterns=["down_proj_blocks", "down_proj_scales"],
                target_patterns=r"down_proj$",
                operations=[Mxfp4Deserialize(self)],
            ),
        ]
|