metal_quantization.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Metal affine quantization integration for transformers.
  16. This module provides:
  17. - ``MetalLinear``: a drop-in replacement for ``nn.Linear`` that stores weights
  18. as affine-quantized uint32 packed tensors and uses the ``quantization-mlx``
  19. Metal kernels for the forward pass.
  20. - ``replace_with_metal_linear``: walks a model and swaps every eligible
  21. ``nn.Linear`` with ``MetalLinear``.
  22. - ``MetalQuantize`` / ``MetalDequantize``: weight conversion operations that
  23. participate in the new ``WeightConverter`` pipeline.
  24. Weight layout (transposed, matching ``affine_qmm_t``):
  25. - ``weight``: ``[N, K_packed]`` (``uint32``) -- K is the packed dimension.
  26. - ``scales``: ``[N, K // group_size]`` (``float16 / bfloat16``)
  27. - ``qbiases``: ``[N, K // group_size]`` (same dtype as scales)
  28. The kernel call is ``affine_qmm_t(x, weight, scales, qbiases, group_size, bits)``
  29. which computes ``y = x @ dequant(weight).T``, identical to ``nn.Linear``.
  30. """
  31. from ..core_model_loading import ConversionOps, _IdentityOp
  32. from ..quantizers.quantizers_utils import should_convert_module
  33. from ..utils import is_torch_available, logging
  34. if is_torch_available():
  35. import torch
  36. import torch.nn as nn
  # Module-level logger, obtained from the package's own ``logging`` helper and keyed on this module's name.
  37. logger = logging.get_logger(__name__)
  # Module-wide cache for the kernel object; stays ``None`` until the first successful load.
  _metal_kernel = None


  def _get_metal_kernel():
      """
      Return the ``quantization-mlx`` kernel, loading it on first use.

      The kernel is fetched through ``get_kernel`` (imported lazily from ``.hub_kernels``) and
      stored in the module-level ``_metal_kernel`` variable, so subsequent calls return the
      cached object without touching the Hub again.

      Raises:
          ImportError: if importing ``get_kernel`` or fetching the kernel raises any exception.
              The original exception is chained as the cause.
      """
      global _metal_kernel

      # Fast path: the kernel was already loaded by an earlier call.
      if _metal_kernel is not None:
          return _metal_kernel

      try:
          from .hub_kernels import get_kernel

          _metal_kernel = get_kernel("kernels-community/mlx-quantization-metal-kernels")
      except Exception as e:
          message = (
              f"Failed to load the quantization-mlx kernel from the Hub: {e}. "
              "Make sure you have `kernels` installed (`pip install kernels`) "
              "and are running on an Apple Silicon machine."
          )
          raise ImportError(message) from e

      return _metal_kernel
  53. # ---------------------------------------------------------------------------
  54. # MetalLinear -- the quantized nn.Linear replacement
  55. # ---------------------------------------------------------------------------
  56. class MetalLinear(nn.Linear):
  57. """
  58. A quantized linear layer that stores weights in affine uint32 packed format
  59. and uses the ``quantization-mlx`` Metal kernels for the forward pass.
  60. Parameters match ``nn.Linear`` with additional quantization metadata.
  61. """
  62. def __init__(
  63. self,
  64. in_features: int,
  65. out_features: int,
  66. bias: bool = False,
  67. dtype=torch.uint32,
  68. bits: int = 4,
  69. group_size: int = 128,
  70. ):
  71. nn.Module.__init__(self)
  72. self.in_features = in_features
  73. self.out_features = out_features
  74. self.bits = bits
  75. self.group_size = group_size
  76. elems_per_int = 32 // bits
  77. k_packed = in_features // elems_per_int
  78. n_groups = in_features // group_size
  79. if dtype == torch.uint32:
  80. self.weight = nn.Parameter(torch.zeros(out_features, k_packed, dtype=torch.uint32), requires_grad=False)
  81. else:
  82. self.weight = nn.Parameter(torch.zeros(out_features, in_features, dtype=dtype), requires_grad=False)
  83. scales_dtype = torch.float32 if dtype == torch.uint32 else None
  84. self.scales = nn.Parameter(torch.zeros(out_features, n_groups, dtype=scales_dtype), requires_grad=False)
  85. self.qbiases = nn.Parameter(torch.zeros(out_features, n_groups, dtype=scales_dtype), requires_grad=False)
  86. if bias:
  87. self.bias = nn.Parameter(torch.zeros(out_features))
  88. else:
  89. self.register_parameter("bias", None)
  90. def forward(self, input: torch.Tensor) -> torch.Tensor:
  91. if self.weight.dtype != torch.uint32:
  92. return nn.functional.linear(input, self.weight, self.bias)
  93. kernel = _get_metal_kernel()
  94. output = kernel.affine_qmm_t(
  95. input,
  96. self.weight,
  97. self.scales.to(input.dtype),
  98. self.qbiases.to(input.dtype),
  99. self.group_size,
  100. self.bits,
  101. )
  102. if self.bias is not None:
  103. output = output + self.bias
  104. return output
  105. def replace_with_metal_linear(
  106. model,
  107. modules_to_not_convert: list[str] | None = None,
  108. quantization_config=None,
  109. pre_quantized: bool = False,
  110. ):
  111. """
  112. Replace every eligible ``nn.Linear`` with ``MetalLinear``.
  113. Args:
  114. model: the ``PreTrainedModel`` (on the meta device at this point).
  115. modules_to_not_convert: module names to leave untouched.
  116. quantization_config: the ``MetalConfig`` instance.
  117. pre_quantized: ``True`` when loading from a quantized checkpoint.
  118. """
  119. if quantization_config.dequantize:
  120. return model
  121. bits = quantization_config.bits
  122. group_size = quantization_config.group_size
  123. has_been_replaced = False
  124. for module_name, module in model.named_modules():
  125. if not should_convert_module(module_name, modules_to_not_convert):
  126. continue
  127. if isinstance(module, nn.Linear):
  128. module_kwargs = {} if pre_quantized else {"dtype": None}
  129. new_module = MetalLinear(
  130. in_features=module.in_features,
  131. out_features=module.out_features,
  132. bias=module.bias is not None,
  133. bits=bits,
  134. group_size=group_size,
  135. **module_kwargs,
  136. )
  137. model.set_submodule(module_name, new_module)
  138. has_been_replaced = True
  139. if not has_been_replaced:
  140. logger.warning(
  141. "You are loading a model with Metal quantization but no nn.Linear modules were found. "
  142. "Please double check your model architecture."
  143. )
  144. return model
  145. def _affine_quantize_tensor(weight: torch.Tensor, group_size: int, bits: int):
  146. """
  147. Quantize a 2-D float weight ``[N, K]`` into packed uint32 + scales + biases.
  148. Returns ``(w_packed, scales, biases)`` with:
  149. - ``w_packed``: ``[N, K // (32 // bits)]`` uint32
  150. - ``scales``: ``[N, K // group_size]`` float32/float16/bfloat16
  151. - ``biases``: ``[N, K // group_size]`` float32/float16/bfloat16
  152. """
  153. N, K = weight.shape
  154. elems_per_int = 32 // bits
  155. max_val = (1 << bits) - 1
  156. n_groups = K // group_size
  157. w_grouped = weight.float().reshape(N, n_groups, group_size)
  158. w_min = w_grouped.min(dim=-1).values # [N, n_groups]
  159. w_max = w_grouped.max(dim=-1).values
  160. scales = ((w_max - w_min) / max_val).clamp(min=1e-8)
  161. biases = w_min
  162. w_int = (w_grouped - biases.unsqueeze(-1)) / scales.unsqueeze(-1)
  163. w_int = w_int.round().clamp(0, max_val).to(torch.int32).reshape(N, K)
  164. # Pack into uint32
  165. k_packed = K // elems_per_int
  166. w_packed = torch.zeros(N, k_packed, dtype=torch.int32, device=weight.device)
  167. for i in range(elems_per_int):
  168. w_packed |= w_int[:, i::elems_per_int] << (bits * i)
  169. return w_packed.to(torch.uint32), scales, biases
  170. def _affine_dequantize_tensor(
  171. w_packed: torch.Tensor, scales: torch.Tensor, biases: torch.Tensor, group_size: int, bits: int
  172. ):
  173. """
  174. Dequantize a packed uint32 weight ``[N, K_packed]`` back to float.
  175. Returns a ``[N, K]`` float32 tensor.
  176. """
  177. N = w_packed.shape[0]
  178. elems_per_int = 32 // bits
  179. max_val = (1 << bits) - 1
  180. K = w_packed.shape[1] * elems_per_int
  181. w_packed_i = w_packed.to(torch.int32)
  182. w_flat = torch.zeros(N, K, dtype=torch.float32, device=w_packed.device)
  183. for i in range(elems_per_int):
  184. w_flat[:, i::elems_per_int] = ((w_packed_i >> (bits * i)) & max_val).float()
  185. w_grouped = w_flat.reshape(N, -1, group_size)
  186. w_deq = w_grouped * scales.float().unsqueeze(-1) + biases.float().unsqueeze(-1)
  187. return w_deq.reshape(N, K)
  class MetalQuantize(ConversionOps):
      """
      Conversion op turning a full-precision weight into ``(weight, scales, qbiases)``.

      Used during quantize-on-the-fly: the entry stored under the incoming weight key is
      replaced by the packed uint32 tensor, and sibling ``scales`` / ``qbiases`` entries are
      emitted next to it.
      """

      def __init__(self, hf_quantizer):
          self.hf_quantizer = hf_quantizer

      def convert(self, input_dict: dict, **kwargs) -> dict:
          """
          Quantize the first entry of ``input_dict``.

          The value may be a tensor or a list of tensors (only the first element is used).
          ``bits`` and ``group_size`` come from ``hf_quantizer.quantization_config``. The
          returned dict maps the original key to the packed weight and ``<prefix>.scales`` /
          ``<prefix>.qbiases`` (or bare ``scales`` / ``qbiases`` when the key has no prefix)
          to the metadata, cast back to the input tensor's dtype.
          """
          target_key, tensor = next(iter(input_dict.items()))
          if isinstance(tensor, list):
              tensor = tensor[0]

          config = self.hf_quantizer.quantization_config
          packed, scales, biases = _affine_quantize_tensor(tensor, config.group_size, config.bits)

          # Everything before the last dot; empty when the key contains no dot.
          prefix = target_key.rpartition(".")[0]
          scale_key = f"{prefix}.scales" if prefix else "scales"
          bias_key = f"{prefix}.qbiases" if prefix else "qbiases"

          source_dtype = tensor.dtype
          return {
              target_key: packed,
              scale_key: scales.to(source_dtype),
              bias_key: biases.to(source_dtype),
          }
  class MetalDequantize(ConversionOps):
      """
      Conversion op turning ``(weight, scales, qbiases)`` back into a full-precision tensor.

      Used when ``dequantize=True`` is set in the config to fall back to a normal
      ``nn.Linear`` on devices without MPS.
      """

      def __init__(self, hf_quantizer):
          self.hf_quantizer = hf_quantizer

      def convert(self, input_dict: dict, full_layer_name: str | None = None, **kwargs) -> dict:
          """
          Dequantize the tensors collected in ``input_dict``.

          With fewer than two entries, the ``"weight$"`` entry is passed through untouched.
          Otherwise the first element of each of ``"weight$"``, ``"scales"`` and ``"qbiases"``
          is dequantized, and the result is cast to the dtype of ``scales``. In both cases the
          output is a single-entry dict keyed by ``full_layer_name``.
          """
          config = self.hf_quantizer.quantization_config
          bits, group_size = config.bits, config.group_size

          if len(input_dict) >= 2:
              packed, scales, qbiases = (input_dict[key][0] for key in ("weight$", "scales", "qbiases"))
              restored = _affine_dequantize_tensor(packed, scales, qbiases, group_size, bits)
              return {full_layer_name: restored.to(scales.dtype)}

          # Only the weight was collected: hand it back unchanged.
          return {full_layer_name: input_dict["weight$"]}

      @property
      def reverse_op(self) -> "ConversionOps":
          """The reverse operation is an identity op."""
          return _IdentityOp()
  231. return _IdentityOp()