| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Metal affine quantization integration for transformers.
- This module provides:
- - ``MetalLinear``: a drop-in replacement for ``nn.Linear`` that stores weights
- as affine-quantized uint32 packed tensors and uses the ``quantization-mlx``
- Metal kernels for the forward pass.
- - ``replace_with_metal_linear``: walks a model and swaps every eligible
- ``nn.Linear`` with ``MetalLinear``.
- - ``MetalQuantize`` / ``MetalDequantize``: weight conversion operations that
- participate in the new ``WeightConverter`` pipeline.
- Weight layout (transposed, matching ``affine_qmm_t``):
- - ``weight``: ``[N, K_packed]`` (``uint32``) -- K is the packed dimension.
- - ``scales``: ``[N, K // group_size]`` (``float16 / bfloat16``)
- - ``qbiases``: ``[N, K // group_size]`` (same dtype as scales)
- The kernel call is ``affine_qmm_t(x, weight, scales, qbiases, group_size, bits)``
- which computes ``y = x @ dequant(weight).T``, identical to ``nn.Linear``.
- """
- from ..core_model_loading import ConversionOps, _IdentityOp
- from ..quantizers.quantizers_utils import should_convert_module
- from ..utils import is_torch_available, logging
- if is_torch_available():
- import torch
- import torch.nn as nn
- logger = logging.get_logger(__name__)
- _metal_kernel = None
- def _get_metal_kernel():
- """Lazily load the quantization-mlx kernel from Hugging Face Hub."""
- global _metal_kernel
- if _metal_kernel is None:
- try:
- from .hub_kernels import get_kernel
- _metal_kernel = get_kernel("kernels-community/mlx-quantization-metal-kernels")
- except Exception as e:
- raise ImportError(
- f"Failed to load the quantization-mlx kernel from the Hub: {e}. "
- "Make sure you have `kernels` installed (`pip install kernels`) "
- "and are running on an Apple Silicon machine."
- ) from e
- return _metal_kernel
- # ---------------------------------------------------------------------------
- # MetalLinear -- the quantized nn.Linear replacement
- # ---------------------------------------------------------------------------
- class MetalLinear(nn.Linear):
- """
- A quantized linear layer that stores weights in affine uint32 packed format
- and uses the ``quantization-mlx`` Metal kernels for the forward pass.
- Parameters match ``nn.Linear`` with additional quantization metadata.
- """
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = False,
- dtype=torch.uint32,
- bits: int = 4,
- group_size: int = 128,
- ):
- nn.Module.__init__(self)
- self.in_features = in_features
- self.out_features = out_features
- self.bits = bits
- self.group_size = group_size
- elems_per_int = 32 // bits
- k_packed = in_features // elems_per_int
- n_groups = in_features // group_size
- if dtype == torch.uint32:
- self.weight = nn.Parameter(torch.zeros(out_features, k_packed, dtype=torch.uint32), requires_grad=False)
- else:
- self.weight = nn.Parameter(torch.zeros(out_features, in_features, dtype=dtype), requires_grad=False)
- scales_dtype = torch.float32 if dtype == torch.uint32 else None
- self.scales = nn.Parameter(torch.zeros(out_features, n_groups, dtype=scales_dtype), requires_grad=False)
- self.qbiases = nn.Parameter(torch.zeros(out_features, n_groups, dtype=scales_dtype), requires_grad=False)
- if bias:
- self.bias = nn.Parameter(torch.zeros(out_features))
- else:
- self.register_parameter("bias", None)
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- if self.weight.dtype != torch.uint32:
- return nn.functional.linear(input, self.weight, self.bias)
- kernel = _get_metal_kernel()
- output = kernel.affine_qmm_t(
- input,
- self.weight,
- self.scales.to(input.dtype),
- self.qbiases.to(input.dtype),
- self.group_size,
- self.bits,
- )
- if self.bias is not None:
- output = output + self.bias
- return output
- def replace_with_metal_linear(
- model,
- modules_to_not_convert: list[str] | None = None,
- quantization_config=None,
- pre_quantized: bool = False,
- ):
- """
- Replace every eligible ``nn.Linear`` with ``MetalLinear``.
- Args:
- model: the ``PreTrainedModel`` (on the meta device at this point).
- modules_to_not_convert: module names to leave untouched.
- quantization_config: the ``MetalConfig`` instance.
- pre_quantized: ``True`` when loading from a quantized checkpoint.
- """
- if quantization_config.dequantize:
- return model
- bits = quantization_config.bits
- group_size = quantization_config.group_size
- has_been_replaced = False
- for module_name, module in model.named_modules():
- if not should_convert_module(module_name, modules_to_not_convert):
- continue
- if isinstance(module, nn.Linear):
- module_kwargs = {} if pre_quantized else {"dtype": None}
- new_module = MetalLinear(
- in_features=module.in_features,
- out_features=module.out_features,
- bias=module.bias is not None,
- bits=bits,
- group_size=group_size,
- **module_kwargs,
- )
- model.set_submodule(module_name, new_module)
- has_been_replaced = True
- if not has_been_replaced:
- logger.warning(
- "You are loading a model with Metal quantization but no nn.Linear modules were found. "
- "Please double check your model architecture."
- )
- return model
- def _affine_quantize_tensor(weight: torch.Tensor, group_size: int, bits: int):
- """
- Quantize a 2-D float weight ``[N, K]`` into packed uint32 + scales + biases.
- Returns ``(w_packed, scales, biases)`` with:
- - ``w_packed``: ``[N, K // (32 // bits)]`` uint32
- - ``scales``: ``[N, K // group_size]`` float32/float16/bfloat16
- - ``biases``: ``[N, K // group_size]`` float32/float16/bfloat16
- """
- N, K = weight.shape
- elems_per_int = 32 // bits
- max_val = (1 << bits) - 1
- n_groups = K // group_size
- w_grouped = weight.float().reshape(N, n_groups, group_size)
- w_min = w_grouped.min(dim=-1).values # [N, n_groups]
- w_max = w_grouped.max(dim=-1).values
- scales = ((w_max - w_min) / max_val).clamp(min=1e-8)
- biases = w_min
- w_int = (w_grouped - biases.unsqueeze(-1)) / scales.unsqueeze(-1)
- w_int = w_int.round().clamp(0, max_val).to(torch.int32).reshape(N, K)
- # Pack into uint32
- k_packed = K // elems_per_int
- w_packed = torch.zeros(N, k_packed, dtype=torch.int32, device=weight.device)
- for i in range(elems_per_int):
- w_packed |= w_int[:, i::elems_per_int] << (bits * i)
- return w_packed.to(torch.uint32), scales, biases
- def _affine_dequantize_tensor(
- w_packed: torch.Tensor, scales: torch.Tensor, biases: torch.Tensor, group_size: int, bits: int
- ):
- """
- Dequantize a packed uint32 weight ``[N, K_packed]`` back to float.
- Returns a ``[N, K]`` float32 tensor.
- """
- N = w_packed.shape[0]
- elems_per_int = 32 // bits
- max_val = (1 << bits) - 1
- K = w_packed.shape[1] * elems_per_int
- w_packed_i = w_packed.to(torch.int32)
- w_flat = torch.zeros(N, K, dtype=torch.float32, device=w_packed.device)
- for i in range(elems_per_int):
- w_flat[:, i::elems_per_int] = ((w_packed_i >> (bits * i)) & max_val).float()
- w_grouped = w_flat.reshape(N, -1, group_size)
- w_deq = w_grouped * scales.float().unsqueeze(-1) + biases.float().unsqueeze(-1)
- return w_deq.reshape(N, K)
- class MetalQuantize(ConversionOps):
- """
- Quantize a full-precision weight tensor into (weight, scales, qbiases).
- Used during quantize-on-the-fly. The float ``weight`` is replaced in-place
- by the packed uint32 tensor.
- """
- def __init__(self, hf_quantizer):
- self.hf_quantizer = hf_quantizer
- def convert(self, input_dict: dict, **kwargs) -> dict:
- target_key, value = next(iter(input_dict.items()))
- value = value[0] if isinstance(value, list) else value
- bits = self.hf_quantizer.quantization_config.bits
- group_size = self.hf_quantizer.quantization_config.group_size
- w_packed, scales, biases = _affine_quantize_tensor(value, group_size, bits)
- base = target_key.rsplit(".", 1)[0] if "." in target_key else ""
- scale_key = f"{base}.scales" if base else "scales"
- bias_key = f"{base}.qbiases" if base else "qbiases"
- orig_dtype = value.dtype
- return {
- target_key: w_packed,
- scale_key: scales.to(orig_dtype),
- bias_key: biases.to(orig_dtype),
- }
- class MetalDequantize(ConversionOps):
- """
- Dequantize (weight, scales, qbiases) back to a full-precision tensor.
- Used when ``dequantize=True`` is set in the config to fall back to a normal
- ``nn.Linear`` on devices without MPS.
- """
- def __init__(self, hf_quantizer):
- self.hf_quantizer = hf_quantizer
- def convert(self, input_dict: dict, full_layer_name: str | None = None, **kwargs) -> dict:
- bits = self.hf_quantizer.quantization_config.bits
- group_size = self.hf_quantizer.quantization_config.group_size
- if len(input_dict) < 2:
- return {full_layer_name: input_dict["weight$"]}
- quantized = input_dict["weight$"][0]
- scales = input_dict["scales"][0]
- qbiases = input_dict["qbiases"][0]
- w_deq = _affine_dequantize_tensor(quantized, scales, qbiases, group_size, bits)
- return {full_layer_name: w_deq.to(scales.dtype)}
- @property
- def reverse_op(self) -> "ConversionOps":
- return _IdentityOp()
|