auto.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. # Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import warnings
  16. from ..models.auto.configuration_auto import AutoConfig
  17. from ..utils import logging
  18. from ..utils.quantization_config import (
  19. AqlmConfig,
  20. AutoRoundConfig,
  21. AwqConfig,
  22. BitNetQuantConfig,
  23. BitsAndBytesConfig,
  24. CompressedTensorsConfig,
  25. EetqConfig,
  26. FbgemmFp8Config,
  27. FineGrainedFP8Config,
  28. FourOverSixConfig,
  29. FPQuantConfig,
  30. GPTQConfig,
  31. HiggsConfig,
  32. HqqConfig,
  33. MetalConfig,
  34. Mxfp4Config,
  35. QuantizationConfigMixin,
  36. QuantizationMethod,
  37. QuantoConfig,
  38. QuarkConfig,
  39. SinqConfig,
  40. SpQRConfig,
  41. TorchAoConfig,
  42. VptqConfig,
  43. )
  44. from .base import HfQuantizer
  45. from .quantizer_aqlm import AqlmHfQuantizer
  46. from .quantizer_auto_round import AutoRoundQuantizer
  47. from .quantizer_awq import AwqQuantizer
  48. from .quantizer_bitnet import BitNetHfQuantizer
  49. from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
  50. from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
  51. from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer
  52. from .quantizer_eetq import EetqHfQuantizer
  53. from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
  54. from .quantizer_finegrained_fp8 import FineGrainedFP8HfQuantizer
  55. from .quantizer_fouroversix import FourOverSixHfQuantizer
  56. from .quantizer_fp_quant import FPQuantHfQuantizer
  57. from .quantizer_gptq import GptqHfQuantizer
  58. from .quantizer_higgs import HiggsHfQuantizer
  59. from .quantizer_hqq import HqqHfQuantizer
  60. from .quantizer_metal import MetalHfQuantizer
  61. from .quantizer_mxfp4 import Mxfp4HfQuantizer
  62. from .quantizer_quanto import QuantoHfQuantizer
  63. from .quantizer_quark import QuarkHfQuantizer
  64. from .quantizer_sinq import SinqHfQuantizer
  65. from .quantizer_spqr import SpQRHfQuantizer
  66. from .quantizer_torchao import TorchAoHfQuantizer
  67. from .quantizer_vptq import VptqHfQuantizer
  68. AUTO_QUANTIZER_MAPPING = {
  69. "awq": AwqQuantizer,
  70. "bitsandbytes_4bit": Bnb4BitHfQuantizer,
  71. "bitsandbytes_8bit": Bnb8BitHfQuantizer,
  72. "gptq": GptqHfQuantizer,
  73. "aqlm": AqlmHfQuantizer,
  74. "quanto": QuantoHfQuantizer,
  75. "quark": QuarkHfQuantizer,
  76. "fouroversix": FourOverSixHfQuantizer,
  77. "fp_quant": FPQuantHfQuantizer,
  78. "eetq": EetqHfQuantizer,
  79. "higgs": HiggsHfQuantizer,
  80. "hqq": HqqHfQuantizer,
  81. "compressed-tensors": CompressedTensorsHfQuantizer,
  82. "fbgemm_fp8": FbgemmFp8HfQuantizer,
  83. "torchao": TorchAoHfQuantizer,
  84. "bitnet": BitNetHfQuantizer,
  85. "vptq": VptqHfQuantizer,
  86. "spqr": SpQRHfQuantizer,
  87. "fp8": FineGrainedFP8HfQuantizer,
  88. "auto-round": AutoRoundQuantizer,
  89. "mxfp4": Mxfp4HfQuantizer,
  90. "metal": MetalHfQuantizer,
  91. "sinq": SinqHfQuantizer,
  92. }
  93. AUTO_QUANTIZATION_CONFIG_MAPPING = {
  94. "awq": AwqConfig,
  95. "bitsandbytes_4bit": BitsAndBytesConfig,
  96. "bitsandbytes_8bit": BitsAndBytesConfig,
  97. "eetq": EetqConfig,
  98. "gptq": GPTQConfig,
  99. "aqlm": AqlmConfig,
  100. "quanto": QuantoConfig,
  101. "quark": QuarkConfig,
  102. "fouroversix": FourOverSixConfig,
  103. "fp_quant": FPQuantConfig,
  104. "hqq": HqqConfig,
  105. "compressed-tensors": CompressedTensorsConfig,
  106. "fbgemm_fp8": FbgemmFp8Config,
  107. "higgs": HiggsConfig,
  108. "torchao": TorchAoConfig,
  109. "bitnet": BitNetQuantConfig,
  110. "vptq": VptqConfig,
  111. "spqr": SpQRConfig,
  112. "fp8": FineGrainedFP8Config,
  113. "auto-round": AutoRoundConfig,
  114. "mxfp4": Mxfp4Config,
  115. "metal": MetalConfig,
  116. "sinq": SinqConfig,
  117. }
  118. LOADING_ATTRIBUTES_CONFIG_TYPES = (
  119. GPTQConfig,
  120. AwqConfig,
  121. AutoRoundConfig,
  122. FbgemmFp8Config,
  123. CompressedTensorsConfig,
  124. Mxfp4Config,
  125. MetalConfig,
  126. FineGrainedFP8Config,
  127. )
  128. logger = logging.get_logger(__name__)
  129. class AutoQuantizationConfig:
  130. """
  131. The Auto-HF quantization config class that takes care of automatically dispatching to the correct
  132. quantization config given a quantization config stored in a dictionary.
  133. """
  134. @classmethod
  135. def from_dict(cls, quantization_config_dict: dict):
  136. quant_method = quantization_config_dict.get("quant_method")
  137. # We need a special care for bnb models to make sure everything is BC ..
  138. if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
  139. suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
  140. quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
  141. elif quant_method is None:
  142. raise ValueError(
  143. "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
  144. )
  145. if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
  146. raise ValueError(
  147. f"Unknown quantization type, got {quant_method} - supported types are:"
  148. f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
  149. )
  150. target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
  151. return target_cls.from_dict(quantization_config_dict)
  152. @classmethod
  153. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  154. model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
  155. if getattr(model_config, "quantization_config", None) is None:
  156. raise ValueError(
  157. f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
  158. )
  159. quantization_config_dict = model_config.quantization_config
  160. quantization_config = cls.from_dict(quantization_config_dict)
  161. # Update with potential kwargs that are passed through from_pretrained.
  162. quantization_config.update(**kwargs)
  163. return quantization_config
  164. class AutoHfQuantizer:
  165. """
  166. The Auto-HF quantizer class that takes care of automatically instantiating to the correct
  167. `HfQuantizer` given the `QuantizationConfig`.
  168. """
  169. @classmethod
  170. def from_config(cls, quantization_config: QuantizationConfigMixin | dict, **kwargs):
  171. # Convert it to a QuantizationConfig if the q_config is a dict
  172. if isinstance(quantization_config, dict):
  173. quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
  174. quant_method = quantization_config.quant_method
  175. # Again, we need a special care for bnb as we have a single quantization config
  176. # class for both 4-bit and 8-bit quantization
  177. if quant_method == QuantizationMethod.BITS_AND_BYTES:
  178. if not isinstance(quantization_config, BitsAndBytesConfig):
  179. raise TypeError(
  180. "Found `quant_method=bitsandbytes` but `quantization_config` is not a `BitsAndBytesConfig`."
  181. )
  182. if quantization_config.load_in_8bit:
  183. quant_method += "_8bit"
  184. else:
  185. quant_method += "_4bit"
  186. if quant_method not in AUTO_QUANTIZER_MAPPING:
  187. raise ValueError(
  188. f"Unknown quantization type, got {quant_method} - supported types are:"
  189. f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
  190. )
  191. target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
  192. return target_cls(quantization_config, **kwargs)
  193. @classmethod
  194. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  195. quantization_config = AutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
  196. return cls.from_config(quantization_config)
  197. @classmethod
  198. def merge_quantization_configs(
  199. cls,
  200. quantization_config: dict | QuantizationConfigMixin,
  201. quantization_config_from_args: QuantizationConfigMixin | None,
  202. ):
  203. """
  204. handles situations where both quantization_config from args and quantization_config from model config are present.
  205. """
  206. if quantization_config_from_args is not None:
  207. warning_msg = (
  208. "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
  209. " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
  210. )
  211. else:
  212. warning_msg = ""
  213. if isinstance(quantization_config, dict):
  214. # Convert the config based on the type of quantization_config_from_args (e.g., AutoRoundConfig), which takes priority before automatic configuration dispatch.
  215. if isinstance(quantization_config_from_args, AutoRoundConfig):
  216. quantization_config = AutoRoundConfig.from_dict(quantization_config)
  217. else:
  218. quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
  219. if (
  220. quantization_config_from_args is not None
  221. and quantization_config.__class__.__name__ != quantization_config_from_args.__class__.__name__
  222. ):
  223. raise ValueError(
  224. f"The model is quantized with {quantization_config.__class__.__name__} but you are passing a {quantization_config_from_args.__class__.__name__} config. "
  225. "Please make sure to pass the same quantization config class to `from_pretrained` with different loading attributes."
  226. )
  227. if isinstance(quantization_config, LOADING_ATTRIBUTES_CONFIG_TYPES) and isinstance(
  228. quantization_config_from_args, LOADING_ATTRIBUTES_CONFIG_TYPES
  229. ):
  230. loading_attr_dict = quantization_config_from_args.get_loading_attributes()
  231. for attr, val in loading_attr_dict.items():
  232. setattr(quantization_config, attr, val)
  233. if loading_attr_dict:
  234. warning_msg += f"However, loading attributes (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
  235. if warning_msg != "" and not isinstance(quantization_config, (Mxfp4Config, MetalConfig, FineGrainedFP8Config)):
  236. warnings.warn(warning_msg)
  237. else:
  238. # in the case of mxfp4, we don't want to print the warning message, bit confusing for users
  239. logger.info(warning_msg)
  240. return quantization_config
  241. @staticmethod
  242. def supports_quant_method(quantization_config_dict):
  243. quant_method = quantization_config_dict.get("quant_method", None)
  244. if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
  245. suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
  246. quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
  247. elif quant_method is None:
  248. raise ValueError(
  249. "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
  250. )
  251. if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
  252. logger.warning(
  253. f"Unknown quantization type, got {quant_method} - supported types are:"
  254. f" {list(AUTO_QUANTIZER_MAPPING.keys())}. Hence, we will skip the quantization. "
  255. "To remove the warning, you can delete the quantization_config attribute in config.json"
  256. )
  257. return False
  258. return True
  259. def register_quantization_config(method: str):
  260. """Register a custom quantization configuration."""
  261. def register_config_fn(cls):
  262. if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
  263. raise ValueError(f"Config '{method}' already registered")
  264. if not issubclass(cls, QuantizationConfigMixin):
  265. raise TypeError("Config must extend QuantizationConfigMixin")
  266. AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
  267. return cls
  268. return register_config_fn
  269. def register_quantizer(name: str):
  270. """Register a custom quantizer."""
  271. def register_quantizer_fn(cls):
  272. if name in AUTO_QUANTIZER_MAPPING:
  273. raise ValueError(f"Quantizer '{name}' already registered")
  274. if not issubclass(cls, HfQuantizer):
  275. raise TypeError("Quantizer must extend HfQuantizer")
  276. AUTO_QUANTIZER_MAPPING[name] = cls
  277. return cls
  278. return register_quantizer_fn
  279. def get_hf_quantizer(config, quantization_config, device_map, weights_only, user_agent):
  280. pre_quantized = hasattr(config, "quantization_config")
  281. if pre_quantized and not AutoHfQuantizer.supports_quant_method(config.quantization_config):
  282. pre_quantized = False
  283. if pre_quantized or quantization_config is not None:
  284. if pre_quantized:
  285. config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
  286. config.quantization_config, quantization_config
  287. )
  288. else:
  289. config.quantization_config = quantization_config
  290. hf_quantizer = AutoHfQuantizer.from_config(
  291. config.quantization_config,
  292. pre_quantized=pre_quantized,
  293. )
  294. else:
  295. hf_quantizer = None
  296. if hf_quantizer is not None:
  297. hf_quantizer.validate_environment(
  298. device_map=device_map,
  299. weights_only=weights_only,
  300. )
  301. device_map = hf_quantizer.update_device_map(device_map)
  302. config = hf_quantizer.update_tp_plan(config)
  303. config = hf_quantizer.update_ep_plan(config)
  304. # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
  305. if not getattr(hf_quantizer.quantization_config, "dequantize", False):
  306. quant_method = hf_quantizer.quantization_config.quant_method
  307. user_agent["quant"] = getattr(quant_method, "value", quant_method)
  308. return hf_quantizer, config, device_map