quantizer_metal.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Any
  15. from ..utils import is_kernels_available, is_torch_available, logging
  16. from .base import HfQuantizer
  17. from .quantizers_utils import get_module_from_name
  18. if is_torch_available():
  19. import torch
  20. if TYPE_CHECKING:
  21. from ..modeling_utils import PreTrainedModel
  22. from ..utils.quantization_config import MetalConfig
  23. logger = logging.get_logger(__name__)
  24. class MetalHfQuantizer(HfQuantizer):
  25. """
  26. Quantizer for Metal affine quantization on Apple Silicon (MPS) devices.
  27. Uses the ``quantization-mlx`` Metal kernels from the Hub to pack weights into
  28. low-bit (2/4/8) uint32 tensors with per-group scales and biases, and performs
  29. fused dequant + matmul in the forward pass.
  30. """
  31. requires_calibration = False
  32. quantization_config: "MetalConfig"
  33. def __init__(self, quantization_config, **kwargs):
  34. super().__init__(quantization_config, **kwargs)
  35. def validate_environment(self, *args, **kwargs):
  36. if self.quantization_config.dequantize:
  37. return
  38. if not torch.backends.mps.is_available():
  39. if self.pre_quantized:
  40. logger.warning_once(
  41. "Metal quantization requires an Apple Silicon GPU (MPS), but none is available. "
  42. "We will default to dequantizing the model to the original dtype."
  43. )
  44. self.quantization_config.dequantize = True
  45. return
  46. else:
  47. raise RuntimeError("Metal quantization requires an Apple Silicon GPU (MPS). No MPS device found.")
  48. if not is_kernels_available():
  49. raise ImportError("Metal quantization requires kernels: `pip install kernels`")
  50. device_map = kwargs.get("device_map")
  51. if device_map is None:
  52. logger.warning_once(
  53. "You have loaded a Metal quantized model on CPU and have an MPS device available. "
  54. "Set device_map='mps' to use the Metal kernels."
  55. )
  56. elif isinstance(device_map, dict):
  57. if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
  58. raise ValueError(
  59. "Metal quantization on the fly does not support CPU or disk in the device_map. "
  60. "Please use a pre-quantized checkpoint or remove CPU/disk from device_map."
  61. )
  62. def update_device_map(self, device_map: dict[str, Any] | None) -> dict[str, Any] | None:
  63. if device_map is None:
  64. device_map = {"": "mps"}
  65. return device_map
  66. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  67. from ..integrations.metal_quantization import MetalLinear
  68. module, tensor_name = get_module_from_name(model, param_name)
  69. if isinstance(module, MetalLinear):
  70. if self.pre_quantized or tensor_name != "weight":
  71. return False
  72. return True
  73. return False
  74. def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
  75. from ..integrations.metal_quantization import replace_with_metal_linear
  76. self.modules_to_not_convert = self.get_modules_to_not_convert(
  77. model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
  78. )
  79. model = replace_with_metal_linear(
  80. model,
  81. modules_to_not_convert=self.modules_to_not_convert,
  82. quantization_config=self.quantization_config,
  83. pre_quantized=self.pre_quantized,
  84. )
  85. def is_serializable(self):
  86. return True
  87. @property
  88. def is_trainable(self) -> bool:
  89. return False
  90. def get_quantize_ops(self):
  91. from ..integrations.metal_quantization import MetalQuantize
  92. return MetalQuantize(self)
  93. def get_weight_conversions(self):
  94. from ..core_model_loading import WeightConverter
  95. from ..integrations.metal_quantization import MetalDequantize
  96. if self.pre_quantized and self.quantization_config.dequantize:
  97. return [
  98. WeightConverter(
  99. source_patterns=["weight$", "scales", "qbiases"],
  100. target_patterns="weight",
  101. operations=[MetalDequantize(self)],
  102. )
  103. ]
  104. return []