| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354 |
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
- # Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import warnings
- from ..models.auto.configuration_auto import AutoConfig
- from ..utils import logging
- from ..utils.quantization_config import (
- AqlmConfig,
- AutoRoundConfig,
- AwqConfig,
- BitNetQuantConfig,
- BitsAndBytesConfig,
- CompressedTensorsConfig,
- EetqConfig,
- FbgemmFp8Config,
- FineGrainedFP8Config,
- FourOverSixConfig,
- FPQuantConfig,
- GPTQConfig,
- HiggsConfig,
- HqqConfig,
- MetalConfig,
- Mxfp4Config,
- QuantizationConfigMixin,
- QuantizationMethod,
- QuantoConfig,
- QuarkConfig,
- SinqConfig,
- SpQRConfig,
- TorchAoConfig,
- VptqConfig,
- )
- from .base import HfQuantizer
- from .quantizer_aqlm import AqlmHfQuantizer
- from .quantizer_auto_round import AutoRoundQuantizer
- from .quantizer_awq import AwqQuantizer
- from .quantizer_bitnet import BitNetHfQuantizer
- from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
- from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
- from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
- from .quantizer_eetq import EetqHfQuantizer
- from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
- from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
- from .quantizer_fouroversix import FourOverSixHfQuantizer
- from .quantizer_fp_quant import FPQuantHfQuantizer
- from .quantizer_gptq import GptqHfQuantizer
- from .quantizer_higgs import HiggsHfQuantizer
- from .quantizer_hqq import HqqHfQuantizer
- from .quantizer_metal import MetalHfQuantizer
- from .quantizer_mxfp4 import Mxfp4HfQuantizer
- from .quantizer_quanto import QuantoHfQuantizer
- from .quantizer_quark import QuarkHfQuantizer
- from .quantizer_sinq import SinqHfQuantizer
- from .quantizer_spqr import SpQRHfQuantizer
- from .quantizer_torchao import TorchAoHfQuantizer
- from .quantizer_vptq import VptqHfQuantizer
# Maps a quantization-method key to the `HfQuantizer` subclass that implements it.
# Keys are the `quant_method` values of quantization configs, except that bitsandbytes is split
# into "bitsandbytes_4bit" / "bitsandbytes_8bit" (the suffix is derived from the config's
# `load_in_4bit` / `load_in_8bit` settings). Custom entries are added through `register_quantizer`.
# NOTE: insertion order is user-visible — `list(AUTO_QUANTIZER_MAPPING.keys())` is embedded in the
# "Unknown quantization type" error messages of this module, so keep new entries appended.
AUTO_QUANTIZER_MAPPING = {
    "awq": AwqQuantizer,
    "bitsandbytes_4bit": Bnb4BitHfQuantizer,
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
    "aqlm": AqlmHfQuantizer,
    "quanto": QuantoHfQuantizer,
    "quark": QuarkHfQuantizer,
    "fouroversix": FourOverSixHfQuantizer,
    "fp_quant": FPQuantHfQuantizer,
    "eetq": EetqHfQuantizer,
    "higgs": HiggsHfQuantizer,
    "hqq": HqqHfQuantizer,
    "compressed-tensors": CompressedTensorsHfQuantizer,
    "fbgemm_fp8": FbgemmFp8HfQuantizer,
    "torchao": TorchAoHfQuantizer,
    "bitnet": BitNetHfQuantizer,
    "vptq": VptqHfQuantizer,
    "spqr": SpQRHfQuantizer,
    "fp8": FineGrainedFP8HfQuantizer,
    "auto-round": AutoRoundQuantizer,
    "mxfp4": Mxfp4HfQuantizer,
    "metal": MetalHfQuantizer,
    "sinq": SinqHfQuantizer,
}
# Maps a quantization-method key to the config class used to deserialize a stored quantization
# config. Uses the same keys as the quantizer mapping; both bitsandbytes variants share the single
# `BitsAndBytesConfig` class. Custom entries are added through `register_quantization_config`.
# NOTE: keep entries appended — the key list is shown to users in "Unknown quantization type" messages.
AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "awq": AwqConfig,
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "eetq": EetqConfig,
    "gptq": GPTQConfig,
    "aqlm": AqlmConfig,
    "quanto": QuantoConfig,
    "quark": QuarkConfig,
    "fouroversix": FourOverSixConfig,
    "fp_quant": FPQuantConfig,
    "hqq": HqqConfig,
    "compressed-tensors": CompressedTensorsConfig,
    "fbgemm_fp8": FbgemmFp8Config,
    "higgs": HiggsConfig,
    "torchao": TorchAoConfig,
    "bitnet": BitNetQuantConfig,
    "vptq": VptqConfig,
    "spqr": SpQRConfig,
    "fp8": FineGrainedFP8Config,
    "auto-round": AutoRoundConfig,
    "mxfp4": Mxfp4Config,
    "metal": MetalConfig,
    "sinq": SinqConfig,
}
# Config classes (used as an `isinstance` class tuple) for which the *loading* attributes of a
# user-supplied config — as returned by its `get_loading_attributes()` — are copied onto the config
# stored with a pre-quantized checkpoint (see `AutoHfQuantizer.merge_quantization_configs`).
LOADING_ATTRIBUTES_CONFIG_TYPES = (
    GPTQConfig,
    AwqConfig,
    AutoRoundConfig,
    FbgemmFp8Config,
    CompressedTensorsConfig,
    Mxfp4Config,
    MetalConfig,
    FineGrainedFP8Config,
)

# Module-level logger.
logger = logging.get_logger(__name__)
class AutoQuantizationConfig:
    """
    The Auto-HF quantization config class that takes care of automatically dispatching to the correct
    quantization config given a quantization config stored in a dictionary.
    """

    @classmethod
    def from_dict(cls, quantization_config_dict: dict):
        """Instantiate the registered quantization config class matching a serialized config.

        Args:
            quantization_config_dict: serialized quantization config (for example the
                `quantization_config` entry of a model configuration).

        Returns:
            An instance of the class registered in `AUTO_QUANTIZATION_CONFIG_MAPPING` for the resolved
            method, built with that class's own `from_dict`.

        Raises:
            ValueError: if no `quant_method` can be determined, or if the method is not registered.
        """
        quant_method = quantization_config_dict.get("quant_method")
        # Backward compatibility for bitsandbytes: the `load_in_4bit` / `load_in_8bit` flags take
        # precedence over `quant_method` and select the per-bit-width key (4-bit wins if both are set).
        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
            )

        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
            # Report the keys of the mapping actually consulted here; it can differ from the quantizer
            # mapping once custom configs are registered.
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}"
            )
        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
        return target_cls.from_dict(quantization_config_dict)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the quantization config stored with a pretrained model.

        Args:
            pretrained_model_name_or_path: forwarded to `AutoConfig.from_pretrained`.
            **kwargs: forwarded to `AutoConfig.from_pretrained`, then applied to the resulting
                quantization config through its `update` method.

        Returns:
            The quantization config object produced by `from_dict`.

        Raises:
            ValueError: if the model configuration has no `quantization_config` (or it is `None`).
        """
        model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        if getattr(model_config, "quantization_config", None) is None:
            raise ValueError(
                f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
            )
        quantization_config_dict = model_config.quantization_config
        quantization_config = cls.from_dict(quantization_config_dict)
        # Update with potential kwargs that are passed through from_pretrained.
        quantization_config.update(**kwargs)
        return quantization_config
class AutoHfQuantizer:
    """
    The Auto-HF quantizer class that takes care of automatically instantiating to the correct
    `HfQuantizer` given the `QuantizationConfig`.
    """

    @classmethod
    def from_config(cls, quantization_config: QuantizationConfigMixin | dict, **kwargs):
        """Instantiate the `HfQuantizer` registered for a quantization config.

        Args:
            quantization_config: a quantization config object, or its serialized dict form (converted
                first with `AutoQuantizationConfig.from_dict`).
            **kwargs: forwarded to the quantizer constructor (e.g. `pre_quantized`).

        Returns:
            An instance of the class registered in `AUTO_QUANTIZER_MAPPING` for the config's method.

        Raises:
            TypeError: if `quant_method` is bitsandbytes but the config is not a `BitsAndBytesConfig`.
            ValueError: if the resolved method has no registered quantizer.
        """
        # Convert it to a QuantizationConfig if the q_config is a dict
        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        quant_method = quantization_config.quant_method

        # bitsandbytes has a single config class for both 4-bit and 8-bit quantization, while the
        # quantizer mapping has one key per bit width, so derive the suffix from `load_in_8bit`.
        if quant_method == QuantizationMethod.BITS_AND_BYTES:
            if not isinstance(quantization_config, BitsAndBytesConfig):
                raise TypeError(
                    "Found `quant_method=bitsandbytes` but `quantization_config` is not a `BitsAndBytesConfig`."
                )
            if quantization_config.load_in_8bit:
                quant_method += "_8bit"
            else:
                quant_method += "_4bit"

        if quant_method not in AUTO_QUANTIZER_MAPPING:
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
        return target_cls(quantization_config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the quantizer for a pretrained model's stored quantization config.

        `kwargs` are consumed by `AutoQuantizationConfig.from_pretrained` and are not forwarded to
        the quantizer constructor.
        """
        quantization_config = AutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls.from_config(quantization_config)

    @classmethod
    def merge_quantization_configs(
        cls,
        quantization_config: dict | QuantizationConfigMixin,
        quantization_config_from_args: QuantizationConfigMixin | None,
    ):
        """
        Handles situations where both quantization_config from args and quantization_config from model config are present.

        The model's stored config always wins. When both configs are instances of
        `LOADING_ATTRIBUTES_CONFIG_TYPES`, the attributes returned by the user config's
        `get_loading_attributes()` are copied onto the stored config.

        Args:
            quantization_config: the model's own quantization config (object or serialized dict).
            quantization_config_from_args: the config passed to `from_pretrained`, or `None`.

        Returns:
            The model's quantization config as an object, possibly with loading attributes overridden.

        Raises:
            ValueError: if both configs are present but their classes have different names.
        """
        if quantization_config_from_args is not None:
            warning_msg = (
                "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
                " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
            )
        else:
            warning_msg = ""

        if isinstance(quantization_config, dict):
            # A user-supplied AutoRoundConfig takes priority over automatic dispatch on `quant_method`.
            if isinstance(quantization_config_from_args, AutoRoundConfig):
                quantization_config = AutoRoundConfig.from_dict(quantization_config)
            else:
                quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        if (
            quantization_config_from_args is not None
            and quantization_config.__class__.__name__ != quantization_config_from_args.__class__.__name__
        ):
            raise ValueError(
                f"The model is quantized with {quantization_config.__class__.__name__} but you are passing a {quantization_config_from_args.__class__.__name__} config. "
                "Please make sure to pass the same quantization config class to `from_pretrained` with different loading attributes."
            )

        if isinstance(quantization_config, LOADING_ATTRIBUTES_CONFIG_TYPES) and isinstance(
            quantization_config_from_args, LOADING_ATTRIBUTES_CONFIG_TYPES
        ):
            loading_attr_dict = quantization_config_from_args.get_loading_attributes()
            for attr, val in loading_attr_dict.items():
                setattr(quantization_config, attr, val)

            if loading_attr_dict:
                # `warning_msg` is non-empty here and ends with "." and no trailing space, so the
                # appended sentence needs a leading space.
                warning_msg += f" However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."

        # Only report when there is something to say (previously an empty string was logged).
        if warning_msg:
            if isinstance(quantization_config, (Mxfp4Config, MetalConfig, FineGrainedFP8Config)):
                # For these configs the warning is mostly confusing for users; keep it at info level.
                logger.info(warning_msg)
            else:
                warnings.warn(warning_msg)

        return quantization_config

    @staticmethod
    def supports_quant_method(quantization_config_dict):
        """Return whether a model's stored quantization config uses a registered method.

        Args:
            quantization_config_dict: the model configuration's `quantization_config`, either a
                serialized dict or an already-instantiated `QuantizationConfigMixin` (the form that
                `get_hf_quantizer` stores back on the configuration after merging).

        Returns:
            True if the resolved method is a key of `AUTO_QUANTIZATION_CONFIG_MAPPING`; otherwise logs
            a warning and returns False so that loading proceeds without quantization.

        Raises:
            ValueError: if no `quant_method` can be determined.
        """
        if isinstance(quantization_config_dict, QuantizationConfigMixin):
            # Instantiated configs expose their fields as attributes, not through `.get`.
            config_obj = quantization_config_dict

            def lookup(key, default=None):
                return getattr(config_obj, key, default)

        else:
            lookup = quantization_config_dict.get

        quant_method = lookup("quant_method", None)
        # Same bitsandbytes backward-compatibility rule as `AutoQuantizationConfig.from_dict`.
        if lookup("load_in_8bit", False) or lookup("load_in_4bit", False):
            suffix = "_4bit" if lookup("load_in_4bit", False) else "_8bit"
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
            )

        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
            # List the keys of the mapping this check actually uses.
            logger.warning(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZATION_CONFIG_MAPPING.keys())}. Hence, we will skip the quantization. "
                "To remove the warning, you can delete the quantization_config attribute in config.json"
            )
            return False
        return True
def register_quantization_config(method: str):
    """Return a class decorator that registers a custom quantization config under ``method``.

    Applying the decorator stores the class in ``AUTO_QUANTIZATION_CONFIG_MAPPING[method]`` and
    returns the class unchanged.

    Raises (when the decorator is applied):
        ValueError: ``method`` is already a key of ``AUTO_QUANTIZATION_CONFIG_MAPPING``.
        TypeError: the decorated class does not subclass ``QuantizationConfigMixin``.
    """

    def _register(config_cls):
        already_known = method in AUTO_QUANTIZATION_CONFIG_MAPPING
        if already_known:
            raise ValueError(f"Config '{method}' already registered")
        if issubclass(config_cls, QuantizationConfigMixin):
            AUTO_QUANTIZATION_CONFIG_MAPPING[method] = config_cls
            return config_cls
        raise TypeError("Config must extend QuantizationConfigMixin")

    return _register
def register_quantizer(name: str):
    """Return a class decorator that registers a custom ``HfQuantizer`` under ``name``.

    Applying the decorator stores the class in ``AUTO_QUANTIZER_MAPPING[name]`` and returns the
    class unchanged.

    Raises (when the decorator is applied):
        ValueError: ``name`` is already a key of ``AUTO_QUANTIZER_MAPPING``.
        TypeError: the decorated class does not subclass ``HfQuantizer``.
    """

    def _register(quantizer_cls):
        duplicate = name in AUTO_QUANTIZER_MAPPING
        if duplicate:
            raise ValueError(f"Quantizer '{name}' already registered")
        if issubclass(quantizer_cls, HfQuantizer):
            AUTO_QUANTIZER_MAPPING[name] = quantizer_cls
            return quantizer_cls
        raise TypeError("Quantizer must extend HfQuantizer")

    return _register
def get_hf_quantizer(config, quantization_config, device_map, weights_only, user_agent):
    """Build the `HfQuantizer` to use when loading a model (if any) and let it adjust the setup.

    The model counts as pre-quantized when `config` carries a non-`None` `quantization_config` that
    `AutoHfQuantizer.supports_quant_method` accepts. The stored config is then merged with the
    user-supplied `quantization_config` (the stored one wins). Otherwise, the user-supplied config,
    if given, is stored on `config`.

    Args:
        config: model configuration; its `quantization_config` attribute is overwritten whenever a
            quantizer is created, and the quantizer may return an updated configuration.
        quantization_config: quantization config passed by the user, or `None`.
        device_map: requested device map; may be rewritten by the quantizer.
        weights_only: forwarded to the quantizer's `validate_environment`.
        user_agent: dict mutated in place with a `"quant"` entry naming the quantization method,
            unless the quantization config sets `dequantize`.

    Returns:
        Tuple `(hf_quantizer, config, device_map)`. `hf_quantizer` is `None` when no quantization
        applies, in which case `config` and `device_map` are returned untouched.
    """
    # Use the same "is quantized" test as `AutoQuantizationConfig.from_pretrained`. An attribute
    # explicitly set to None must not reach `supports_quant_method`, which cannot handle None.
    stored_config = getattr(config, "quantization_config", None)
    pre_quantized = stored_config is not None and AutoHfQuantizer.supports_quant_method(stored_config)

    if pre_quantized:
        config.quantization_config = AutoHfQuantizer.merge_quantization_configs(stored_config, quantization_config)
    elif quantization_config is not None:
        config.quantization_config = quantization_config
    else:
        return None, config, device_map

    hf_quantizer = AutoHfQuantizer.from_config(
        config.quantization_config,
        pre_quantized=pre_quantized,
    )

    hf_quantizer.validate_environment(
        device_map=device_map,
        weights_only=weights_only,
    )
    device_map = hf_quantizer.update_device_map(device_map)
    config = hf_quantizer.update_tp_plan(config)
    config = hf_quantizer.update_ep_plan(config)

    # Telemetry: record the quantization method so that popular methods keep being supported
    # (can be disabled with `disable_telemetry`). Skipped when the config sets `dequantize`.
    if not getattr(hf_quantizer.quantization_config, "dequantize", False):
        quant_method = hf_quantizer.quantization_config.quant_method
        user_agent["quant"] = getattr(quant_method, "value", quant_method)

    return hf_quantizer, config, device_map
|