| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990 |
- #!/usr/bin/env python
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
- # Modifications Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- import importlib.metadata
- import json
- import os
- from dataclasses import dataclass
- from enum import Enum
- from typing import Any, Optional, Union
- from packaging import version
- from ..utils import (
- is_compressed_tensors_available,
- is_hqq_available,
- is_quark_available,
- is_torch_available,
- is_torchao_available,
- logging,
- )
- if is_torch_available():
- import torch
# Module-level logger obtained from the package's `..utils.logging` helper; the config classes below use it for
# informational messages (e.g. HqqConfig's default-axis notice and BitsAndBytesConfig's unused-kwargs notice).
logger = logging.get_logger(__name__)
class QuantizationMethod(str, Enum):
    """
    Identifiers of the quantization methods known to this module.

    Members subclass `str`, so each one compares equal to its string value and is emitted as that plain string when
    a config dict containing it is passed to `json.dumps`. Config classes in this module store one of these members
    in their `quant_method` attribute.
    """

    # NOTE: member order is significant for iteration (`list(QuantizationMethod)`), so keep new members appended.
    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
    VPTQ = "vptq"
    QUANTO = "quanto"
    EETQ = "eetq"
    HIGGS = "higgs"
    HQQ = "hqq"
    COMPRESSED_TENSORS = "compressed-tensors"
    FBGEMM_FP8 = "fbgemm_fp8"
    TORCHAO = "torchao"
    BITNET = "bitnet"
    SPQR = "spqr"
    FP8 = "fp8"
    QUARK = "quark"
    FPQUANT = "fp_quant"
    AUTOROUND = "auto-round"
    MXFP4 = "mxfp4"
    METAL = "metal"
    FOUR_OVER_SIX = "fouroversix"
    SINQ = "sinq"
class AwqFormat(str, Enum):
    """
    String identifiers, presumably for AWQ weight/packing formats.

    NOTE(review): this enum is not referenced in this part of the file; confirm the exact meaning of each member
    against the AWQ config/quantizer that consumes it. Members subclass `str` and compare equal to their values.
    """

    GEMM = "gemm"
    GEMV = "gemv"
    GEMV_FAST = "gemv_fast"
    LLM_AWQ = "llm-awq"
class AwqBackend(str, Enum):
    """
    String identifiers, presumably naming the kernel backends an AWQ model can be run with.

    NOTE(review): this enum is not referenced in this part of the file; confirm backend semantics (e.g. what `auto`
    resolves to) against its consumers. Members subclass `str` and compare equal to their values.
    """

    LEGACY_AWQ = "autoawq"
    AUTO = "auto"
    AUTO_TRAINABLE = "auto_trainable"
    MACHETE = "machete"
    MARLIN = "marlin"
    EXLLAMA_V2 = "exllama_v2"
    EXLLAMA_V1 = "exllama_v1"
    GEMM = "gemm"
    GEMM_TRITON = "gemm_triton"
    GEMV = "gemv"
    GEMV_FAST = "gemv_fast"
    TORCH_AWQ = "torch_awq"
    TORCH_FUSED_AWQ = "torch_fused_awq"
@dataclass
class QuantizationConfigMixin:
    """
    Mixin class for quantization config.

    Provides dict/JSON (de)serialization helpers shared by the concrete quantization configs. Subclasses set
    `quant_method` and may override `to_dict` / `to_diff_dict` to customize serialization.
    """

    quant_method: QuantizationMethod

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        """
        Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters.

        Args:
            config_dict (`dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            return_unused_kwargs (`bool`,*optional*, defaults to `False`):
                Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
                `PreTrainedModel`.
            kwargs (`dict[str, Any]`):
                Additional parameters from which to initialize the configuration object. Entries whose name is an
                existing attribute of the new instance overwrite that attribute; the others are "unused".

        Returns:
            [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters, or a
            `(config, unused_kwargs)` tuple when `return_unused_kwargs=True`.
        """
        config = cls(**config_dict)

        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        # `kwargs` is this call's own dict, so popping here never affects the caller.
        for key in to_remove:
            kwargs.pop(key, None)

        if return_unused_kwargs:
            return config, kwargs
        else:
            return config

    def to_json_file(self, json_file_path: str | os.PathLike):
        """
        Save this instance to a JSON file.

        The complete dictionary returned by `to_dict()` is written (no diffing against default values), as JSON
        indented by 2 spaces with sorted keys and a trailing newline.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        # Serialize BEFORE opening the file: opening with "w" truncates it immediately, so an error raised by
        # `to_dict()` / `json.dumps` after `open` would leave an empty file in place of the previous contents.
        config_dict = self.to_dict()
        json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(json_string)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `dict[str, Any]`: A deep copy of all the attributes that make up this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    def __iter__(self):
        """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
        yield from copy.deepcopy(self.__dict__).items()

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Default behavior: no diffing implemented for this config, so this returns `to_dict()` unchanged.
        Subclasses may override it to drop entries equal to their defaults.
        """
        return self.to_dict()

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, the output of `to_diff_dict()` is serialized (for this base class that is the
                full `to_dict()`; subclasses may omit entries equal to their default config). Otherwise `to_dict()`
                is serialized.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        config_dict = self.to_diff_dict() if use_diff else self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # Remove all the attributes that were updated, without modifying the input dict
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs
@dataclass
class AutoRoundConfig(QuantizationConfigMixin):
    """This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using AutoRound quantization.

    Args:
        bits (`int`, *optional*, defaults to 4):
            The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
        group_size (`int`, *optional*, defaults to 128): Group-size value; must be greater than 0 or equal to -1.
        sym (`bool`, *optional*, defaults to `True`): Symmetric quantization or not
        backend (`str`, *optional*, defaults to `"auto"`): The kernel to use, e.g., ipex,marlin, exllamav2, triton, etc. Ref. https://github.com/intel/auto-round?tab=readme-ov-file#specify-backend
        kwargs (`dict[str, Any]`, *optional*):
            Extra attributes set verbatim on the instance, e.g. `packing_format` (defaults to `"auto_round:gptq"`).

    Raises:
        ValueError: If `bits` or `group_size` is not supported.
    """

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        sym: bool = True,
        backend: str = "auto",
        **kwargs,
    ):
        self.bits = bits
        self.group_size = group_size
        self.sym = sym
        self.backend = backend
        # Default packing format; `from_dict` passes an explicit one for gptq/awq checkpoints via `kwargs`.
        self.packing_format = "auto_round:gptq"
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Assigned after the kwargs loop so a `quant_method` entry coming from a loaded config cannot override it.
        self.quant_method = QuantizationMethod.AUTOROUND
        self.post_init()

    def post_init(self):
        r"""Safety checker that arguments are correct."""
        if self.bits not in [2, 3, 4, 8]:
            raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("group_size must be greater than 0 or equal to -1")

    def get_loading_attributes(self):
        """Return the attributes that may be overridden at load time (only `backend`)."""
        loading_attributes_dict = {"backend": self.backend}
        return loading_attributes_dict

    def to_dict(self):
        """Serialize to a dict; identical to the base-class deep copy of the instance attributes."""
        config_dict = super().to_dict()
        return config_dict

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        """
        Build an `AutoRoundConfig` from an auto-round, gptq (v1) or awq (gemm) quantization config dict.

        For gptq/awq inputs, `packing_format` is set to `"auto_round:<quant_method>"`. The caller's `config_dict` is
        not modified.

        Raises:
            NotImplementedError: If the format cannot be converted to auto_round.
        """
        quant_method = config_dict["quant_method"]
        # A `(str, Enum)` member (e.g. QuantizationMethod.GPTQ) may be formatted as "QuantizationMethod.GPTQ" by the
        # f-string below on newer Python versions; use its plain string value instead.
        if isinstance(quant_method, Enum):
            quant_method = quant_method.value
        if "auto-round" not in quant_method and "gptq" not in quant_method and "awq" not in quant_method:
            raise NotImplementedError(
                "Failed to convert to auto_round format. Only `gptqv1`, `awq`, and `auto-round` formats are supported."
            )

        # A `meta` entry marks a gptq config produced by newer tooling (not plain gptq v1).
        if "gptq" in quant_method and "meta" in config_dict:
            raise NotImplementedError("Failed to convert gptq format to auto_round format. Only supports `gptqv1`")

        if "awq" in quant_method and config_dict.get("version", "gemm") != "gemm":
            raise NotImplementedError(
                "Failed to convert awq format to auto_round format. Only supports awq format with gemm version"
            )
        if "auto-round" not in quant_method:
            # Shallow copy so the caller's dict is not mutated with an injected `packing_format` key.
            config_dict = {**config_dict, "packing_format": f"auto_round:{quant_method}"}

        return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)
@dataclass
class HqqConfig(QuantizationConfigMixin):
    """
    This is wrapper around hqq's BaseQuantizeConfig.

    Args:
        nbits (`int`, *optional*, defaults to 4):
            Number of bits. Supported values are (8, 4, 3, 2, 1).
        group_size (`int`, *optional*, defaults to 64):
            Group-size value. Supported values are any value that is divisible by weight.shape[axis].
        view_as_float (`bool`, *optional*, defaults to `False`):
            View the quantized weight as float (used in distributed training) if set to `True`.
        axis (`Optional[int]`, *optional*):
            Axis along which grouping is performed. Supported values are 0 or 1. `None` selects 1.
        dynamic_config (dict, *optional*):
            Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
            If set, each layer specified by its id will use its dedicated quantization configuration, and
            `nbits` / `group_size` / `view_as_float` / `axis` are not used to build `quant_config`.
        skip_modules (`list[str]`, *optional*, defaults to `['lm_head']`):
            List of `nn.Linear` layers to skip. `None` selects the default `['lm_head']`.
        kwargs (`dict[str, Any]`, *optional*):
            Accepted for signature compatibility; not used by this constructor.

    Raises:
        ImportError: If a valid HQQ installation is not available.
        ValueError: If `axis` is not 0 or 1.
    """

    def __init__(
        self,
        nbits: int = 4,
        group_size: int = 64,
        view_as_float: bool = False,
        axis: int | None = None,
        dynamic_config: dict | None = None,
        skip_modules: list[str] | None = None,
        **kwargs,
    ):
        if is_hqq_available():
            from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
        else:
            raise ImportError(
                "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
            )

        if axis is None:
            axis = 1
            logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")

        if axis not in [0, 1]:
            raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")

        if dynamic_config is not None:
            self.quant_config = {}
            for key in dynamic_config:
                self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key])
        else:
            self.quant_config = HQQBaseQuantizeConfig(
                nbits=nbits, group_size=group_size, view_as_float=view_as_float, axis=axis
            )

        self.quant_method = QuantizationMethod.HQQ
        # Build a fresh default list per instance: a list literal used as the parameter default would be one shared
        # object, so mutating one config's `skip_modules` would silently change every later default config.
        self.skip_modules = ["lm_head"] if skip_modules is None else skip_modules

        self.post_init()

    def post_init(self):
        r"""
        Validation hook called at the end of `__init__`; currently performs no checks.
        """

    @classmethod
    def from_dict(cls, config: dict[str, Any], return_unused_kwargs=False, **kwargs):
        """
        Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py

        `quant_config` and `skip_modules` are copied from `config` onto a default instance. As in the base class,
        `kwargs` matching existing attributes are applied, and the remaining ones are returned alongside the
        instance when `return_unused_kwargs=True`.
        """
        instance = cls()
        instance.quant_config = config["quant_config"]
        instance.skip_modules = config["skip_modules"]
        unused_kwargs = instance.update(**kwargs)
        if return_unused_kwargs:
            return instance, unused_kwargs
        return instance

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `dict[str, Any]`: `quant_config`, `quant_method` and `skip_modules`. Values are the instance's own
            objects (not copies).
        """
        return {
            "quant_config": self.quant_config,
            "quant_method": self.quant_method,
            "skip_modules": self.skip_modules,
        }

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Removes all attributes from config which correspond to the default config attributes for better readability and
        serializes to a Python dictionary.

        Returns:
            `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = HqqConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict
@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
    """
    Configuration wrapper collecting every option that can be tuned when a model is loaded through `bitsandbytes`.

    The supported schemes are `LLM.int8()`, `FP4` and `NF4`. Should `bitsandbytes` gain further methods, matching
    arguments will be added here.

    Args:
        load_in_8bit (`bool`, *optional*, defaults to `False`):
            Turn on 8-bit quantization via LLM.int8(). Mutually exclusive with `load_in_4bit`.
        load_in_4bit (`bool`, *optional*, defaults to `False`):
            Turn on 4-bit quantization, swapping the Linear layers for the FP4/NF4 layers of `bitsandbytes`.
            Mutually exclusive with `load_in_8bit`.
        llm_int8_threshold (`float`, *optional*, defaults to 6.0):
            Outlier threshold used by the outlier detection of `LLM.int8() : 8-bit Matrix Multiplication for
            Transformers at Scale` (https://huggingface.co/papers/2208.07339). Hidden-state values above it are
            treated as outliers and handled in fp16. Most values are roughly normally distributed in [-3.5, 3.5],
            but large models show systematic outliers, often in [-60, -6] or [6, 60]. Int8 quantization works well
            up to a magnitude of about 5 and degrades noticeably beyond it. 6 is a good default; less stable models
            (small models, fine-tuning) may need a lower value. Must be a `float`.
        llm_int8_skip_modules (`list[str]`, *optional*):
            Explicit list of modules to keep out of 8-bit conversion. Useful for models such as Jukebox, whose heads
            sit in several places rather than only at the end. For `CausalLM` models, for instance, the final
            `lm_head` stays in its original `dtype`.
        llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
            Advanced option for splitting a model so that some parts run in int8 on GPU and others in fp32 on CPU,
            e.g. to offload very large models such as `google/flan-t5-xxl`. int8 operations are not run on CPU.
        llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
            Run LLM.int8() with 16-bit main weights. This avoids converting weights back and forth for the backward
            pass, which helps fine-tuning.
        bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
            Dtype used for computation, which may differ from the input dtype (e.g. fp32 inputs with bf16 compute
            for speed). A string is resolved as `getattr(torch, value)`.
        bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
            Quantization data type of the bnb.nn.Linear4Bit layers: `fp4` or `nf4`.
        bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
            Enable nested quantization: the constants produced by the first quantization are quantized again.
        bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
            Storage dtype used to pack the quantized 4-bit params. A string must be one of 'float16', 'float32',
            'int8', 'uint8', 'float64' or 'bfloat16'.
        kwargs (`dict[str, Any]`, *optional*):
            Unrecognized extra parameters. They are only reported in an info log and are otherwise ignored.
    """

    def __init__(
        self,
        load_in_8bit=False,
        load_in_4bit=False,
        llm_int8_threshold=6.0,
        llm_int8_skip_modules=None,
        llm_int8_enable_fp32_cpu_offload=False,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=None,
        bnb_4bit_quant_type="fp4",
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_storage=None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.BITS_AND_BYTES

        if load_in_4bit and load_in_8bit:
            raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")

        # Write the private backing fields directly; the public setters would cross-check each other before both
        # fields exist.
        self._load_in_8bit = load_in_8bit
        self._load_in_4bit = load_in_4bit
        self.llm_int8_threshold = llm_int8_threshold
        self.llm_int8_skip_modules = llm_int8_skip_modules
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant

        # Resolve the compute dtype: None -> float32, str -> torch attribute, dtype -> unchanged.
        compute_dtype = bnb_4bit_compute_dtype
        if compute_dtype is None:
            compute_dtype = torch.float32
        elif isinstance(compute_dtype, str):
            compute_dtype = getattr(torch, compute_dtype)
        elif not isinstance(compute_dtype, torch.dtype):
            raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
        self.bnb_4bit_compute_dtype = compute_dtype

        # Resolve the storage dtype: None -> uint8, whitelisted str -> torch attribute, dtype -> unchanged.
        permitted_storage_names = ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]
        storage_dtype = bnb_4bit_quant_storage
        if storage_dtype is None:
            storage_dtype = torch.uint8
        elif isinstance(storage_dtype, str):
            if storage_dtype not in permitted_storage_names:
                raise ValueError(
                    "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
                )
            storage_dtype = getattr(torch, storage_dtype)
        elif not isinstance(storage_dtype, torch.dtype):
            raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
        self.bnb_4bit_quant_storage = storage_dtype

        if kwargs:
            logger.info(f"Unused kwargs: {list(kwargs)}. These kwargs are not used in {self.__class__}.")

        self.post_init()

    @property
    def load_in_4bit(self):
        """Whether 4-bit (FP4/NF4) loading is enabled."""
        return self._load_in_4bit

    @load_in_4bit.setter
    def load_in_4bit(self, value: bool):
        if not isinstance(value, bool):
            raise TypeError("load_in_4bit must be a boolean")
        # Reject enabling 4-bit while 8-bit is already on.
        if self.load_in_8bit and value:
            raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
        self._load_in_4bit = value

    @property
    def load_in_8bit(self):
        """Whether 8-bit (LLM.int8()) loading is enabled."""
        return self._load_in_8bit

    @load_in_8bit.setter
    def load_in_8bit(self, value: bool):
        if not isinstance(value, bool):
            raise TypeError("load_in_8bit must be a boolean")
        # Reject enabling 8-bit while 4-bit is already on.
        if self.load_in_4bit and value:
            raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
        self._load_in_8bit = value

    def post_init(self):
        r"""
        Type-check every option in declaration order, raising `TypeError` on the first mismatch.
        """

        def is_bool(value):
            return isinstance(value, bool)

        # (attribute name, validity predicate, error message) -- evaluated in this exact order.
        rules = (
            ("load_in_4bit", is_bool, "load_in_4bit must be a boolean"),
            ("load_in_8bit", is_bool, "load_in_8bit must be a boolean"),
            ("llm_int8_threshold", lambda value: isinstance(value, float), "llm_int8_threshold must be a float"),
            (
                "llm_int8_skip_modules",
                lambda value: value is None or isinstance(value, list),
                "llm_int8_skip_modules must be a list of strings",
            ),
            ("llm_int8_enable_fp32_cpu_offload", is_bool, "llm_int8_enable_fp32_cpu_offload must be a boolean"),
            ("llm_int8_has_fp16_weight", is_bool, "llm_int8_has_fp16_weight must be a boolean"),
            (
                "bnb_4bit_compute_dtype",
                lambda value: value is None or isinstance(value, torch.dtype),
                "bnb_4bit_compute_dtype must be torch.dtype",
            ),
            ("bnb_4bit_quant_type", lambda value: isinstance(value, str), "bnb_4bit_quant_type must be a string"),
            ("bnb_4bit_use_double_quant", is_bool, "bnb_4bit_use_double_quant must be a boolean"),
        )
        for attribute_name, is_valid, error_message in rules:
            if not is_valid(getattr(self, attribute_name)):
                raise TypeError(error_message)

    def is_quantizable(self):
        r"""
        Truthy when either 8-bit or 4-bit loading is enabled (returns `load_in_8bit or load_in_4bit`).
        """
        if self.load_in_8bit:
            return self.load_in_8bit
        return self.load_in_4bit

    def quantization_method(self):
        r"""
        Name of the active scheme: `"llm_int8"`, `"fp4"` or `"nf4"`, or `None` when nothing matches.
        """
        if self.load_in_8bit:
            return "llm_int8"
        if self.load_in_4bit:
            if self.bnb_4bit_quant_type == "fp4":
                return "fp4"
            if self.bnb_4bit_quant_type == "nf4":
                return "nf4"
        return None

    def to_dict(self) -> dict[str, Any]:
        """
        Deep-copy the instance attributes into a dict.

        Torch dtypes are turned into their short names (e.g. `torch.float32` -> `"float32"`), and public
        `load_in_4bit` / `load_in_8bit` keys are added next to the private backing fields.
        """
        serialized = copy.deepcopy(self.__dict__)
        for dtype_field in ("bnb_4bit_compute_dtype", "bnb_4bit_quant_storage"):
            serialized[dtype_field] = str(serialized[dtype_field]).split(".")[1]
        serialized["load_in_4bit"] = self.load_in_4bit
        serialized["load_in_8bit"] = self.load_in_8bit
        return serialized

    def __repr__(self):
        rendered = json.dumps(self.to_dict(), indent=2, sort_keys=True)
        return f"{self.__class__.__name__} {rendered}\n"

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Like `to_dict`, but keeping only the entries whose value differs from a default `BitsAndBytesConfig()`.
        """
        current = self.to_dict()
        defaults = BitsAndBytesConfig().to_dict()
        return {key: value for key, value in current.items() if value != defaults[key]}
class ExllamaVersion(int, Enum):
    """
    Integer enum with members `ONE` (1) and `TWO` (2); members subclass `int` and compare equal to those values.

    NOTE(review): presumably selects the exllama kernel version for GPTQ; it is not referenced in this part of the
    file, so confirm against its consumers.
    """

    ONE = 1
    TWO = 2
- @dataclass
- class GPTQConfig(QuantizationConfigMixin):
- """
- This is a wrapper class about all possible attributes and features that you can play with a model that has been
- loaded using `optimum` api for GPTQ quantization relying on the gptqmodel backend.
- Args:
- bits (`int`):
- The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
- tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
- The tokenizer used to process the dataset. You can pass either:
- - A custom tokenizer object.
- - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
- using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
- dataset (`Union[list[str]]`, *optional*):
- The dataset used for quantization. You can provide your own dataset in a list of string or just use the
- original datasets used in GPTQ paper ['wikitext2','c4','c4-new']
- group_size (`int`, *optional*, defaults to 128):
- The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
- damp_percent (`float`, *optional*, defaults to 0.1):
- The percent of the average Hessian diagonal to use for dampening. Recommended value is 0.1.
- desc_act (`bool`, *optional*, defaults to `False`):
- Whether to quantize columns in order of decreasing activation size. Setting it to False can significantly
- speed up inference but the perplexity may become slightly worse. Also known as act-order.
- act_group_aware (`bool`, *optional*, defaults to `True`):
- Use GAR (group aware activation order) during quantization. Has measurable positive impact on quantization
- quality. Only applicable when `desc_act = False`. Will forced to be `False` when `desc_act = True`.
- sym (`bool`, *optional*, defaults to `True`):
- Whether to use symmetric quantization.
- true_sequential (`bool`, *optional*, defaults to `True`):
- Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing
- the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes
- quantization using inputs that have passed through the previously quantized layers.
- format (`str`, *optional*, defaults to `"gptq"`):
- GPTQ weight format. `gptq` (v1) is supported by gptqmodel. `gptq_v2` is gptqmodel only.
- meta (`dict[str, any]`, *optional*):
- Properties, such as tooling:version, that do not directly contributes to quantization or quant inference are stored in meta.
- i.e. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"]
- backend (`str`, *optional*):
- Controls which kernel to use. Valid values for gptqmodel are `auto`, `auto_trainable` and more. Ref gptqmodel backends:
- https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py
- model_seqlen (`int`, *optional*):
- The maximum sequence length that the model can take.
- block_name_to_quantize (`str`, *optional*):
- The transformers block name to quantize. If None, we will infer the block name using common patterns (e.g. model.layers)
- module_name_preceding_first_block (`list[str]`, *optional*):
- The layers that are preceding the first Transformer block.
- batch_size (`int`, *optional*, defaults to 1):
- The batch size used when processing the dataset
- pad_token_id (`int`, *optional*):
- The pad token id. Needed to prepare the dataset when `batch_size` > 1.
- max_input_length (`int`, *optional*):
- The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
- length. It is specific to the exllama backend with act-order.
- cache_block_outputs (`bool`, *optional*, defaults to `True`):
- Whether to cache block outputs to reuse as inputs for the succeeding block.
- modules_in_block_to_quantize (`list[list[str]]`, *optional*):
- List of list of module names to quantize in the specified block. This argument is useful to exclude certain linear modules from being quantized.
- The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially. If not set, we will quantize all linear layers.
- Example: `modules_in_block_to_quantize =[["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"], ["self_attn.o_proj"]]`.
- In this example, we will first quantize the q,k,v layers simultaneously since they are independent.
- Then, we will quantize `self_attn.o_proj` layer with the q,k,v layers quantized. This way, we will get
- better results since it reflects the real input `self_attn.o_proj` will get when the model is quantized.
- """
def __init__(
    self,
    bits: int,
    tokenizer: Any = None,
    dataset: list[str] | str | None = None,
    group_size: int = 128,
    damp_percent: float = 0.1,
    desc_act: bool = False,
    act_group_aware: bool = True,
    sym: bool = True,
    true_sequential: bool = True,
    format: str = "gptq",
    meta: dict[str, Any] | None = None,
    backend: str | None = None,
    model_seqlen: int | None = None,
    block_name_to_quantize: str | None = None,
    module_name_preceding_first_block: list[str] | None = None,
    batch_size: int = 1,
    pad_token_id: int | None = None,
    max_input_length: int | None = None,
    cache_block_outputs: bool = True,
    modules_in_block_to_quantize: list[list[str]] | None = None,
    **kwargs,
):
    """Store the GPTQ settings on the instance, then validate them with `post_init`.

    `format` is lower-cased. The legacy keyword `checkpoint_format`, when given a
    non-`None` value, overrides it (also lower-cased). A string `backend` is
    lower-cased; any other value is stored as-is. Other `kwargs` are ignored.
    Attributes are assigned in parameter order, so the instance `__dict__` order is stable.
    """
    # Resolve the weight format first; the legacy key wins when it is present.
    weight_format = format.lower()
    legacy_format = kwargs.pop("checkpoint_format", None)
    if legacy_format is not None:
        weight_format = legacy_format.lower()

    self.quant_method = QuantizationMethod.GPTQ
    # Quantization hyper-parameters.
    self.bits, self.tokenizer, self.dataset = bits, tokenizer, dataset
    self.group_size, self.damp_percent = group_size, damp_percent
    self.desc_act, self.act_group_aware = desc_act, act_group_aware
    self.sym, self.true_sequential = sym, true_sequential
    self.format = weight_format
    self.meta = meta
    self.backend = backend.lower() if isinstance(backend, str) else backend
    # Calibration / block-processing settings.
    self.model_seqlen = model_seqlen
    self.block_name_to_quantize = block_name_to_quantize
    self.module_name_preceding_first_block = module_name_preceding_first_block
    self.batch_size, self.pad_token_id = batch_size, pad_token_id
    self.max_input_length = max_input_length
    self.cache_block_outputs = cache_block_outputs
    self.modules_in_block_to_quantize = modules_in_block_to_quantize
    self.post_init()
def get_loading_attributes(self):
    """Return the attributes relevant when loading an already-quantized model.

    Only `max_input_length` and `backend` are returned, and only if they are set on
    the instance. Each value is deep-copied, so mutating the result cannot change the config.
    Only the selected values are copied. Deep-copying the whole `__dict__` would also
    copy unrelated attributes such as `tokenizer` or `dataset`, which may be large or
    not deep-copyable.

    Returns:
        `dict[str, Any]`: attribute name -> copy of its value, in `__dict__` order.
    """
    loading_attributes = ("max_input_length", "backend")
    return {
        name: copy.deepcopy(value)
        for name, value in self.__dict__.items()
        if name in loading_attributes
    }
def post_init(self):
    r"""
    Validate the GPTQ arguments and normalize some of them in place.

    Side effects:
        - `act_group_aware` is forced to `False` (with a warning) when `desc_act` is `True`.
        - a `None` backend is replaced by `"auto"`.

    Raises:
        `ValueError`: if `bits`, `group_size`, `damp_percent` or `dataset` is invalid, or if
            `modules_in_block_to_quantize` is set while the installed `optimum` is older than 1.15.0.
    """
    supported_bits = [2, 3, 4, 8]
    if self.bits not in supported_bits:
        raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
    if self.group_size != -1 and self.group_size <= 0:
        raise ValueError("group_size must be greater than 0 or equal to -1")
    if not 0 < self.damp_percent < 1:
        raise ValueError("damp_percent must between 0 and 1.")

    # `dataset` is either a named calibration set or an explicit list of samples.
    named_datasets = ["wikitext2", "c4", "c4-new"]
    if isinstance(self.dataset, str):
        if self.dataset not in named_datasets:
            raise ValueError(
                "You have entered a string value for dataset. You can only choose between\n"
                f"['wikitext2','c4','c4-new'], but we found {self.dataset}"
            )
    elif self.dataset is not None and not isinstance(self.dataset, list):
        raise ValueError(
            "dataset needs to be either a list of string or a value in\n"
            f"['wikitext2','c4','c4-new'], but we found {self.dataset}"
        )

    # Group-aware act order (GAR) only applies when `desc_act = False`.
    if self.desc_act and self.act_group_aware:
        self.act_group_aware = False
        logger.warning("`act_group_aware` has been auto-disabled as it is not compatible with `desc_act = True`.")

    # Keep the backend default consistent with gptqmodel expectations.
    if self.backend is None:
        self.backend = "auto"

    if self.modules_in_block_to_quantize is None:
        return
    optimum_version = version.parse(importlib.metadata.version("optimum"))
    if optimum_version < version.parse("1.15.0"):
        raise ValueError(
            "You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ."
        )
def to_dict(self) -> dict[str, Any]:
    """Serialize the config, additionally emitting the legacy `checkpoint_format` key.

    `checkpoint_format` mirrors `self.format`, so readers that only understand the
    legacy field still see the weight format.
    """
    return {**super().to_dict(), "checkpoint_format": self.format}
def to_dict_optimum(self):
    """
    Return the dict used to build an optimum GPTQ config.

    This is exactly `self.to_dict()`; no optimum-specific key remapping happens here.
    """
    optimum_dict = self.to_dict()
    return optimum_dict
@classmethod
def from_dict_optimum(cls, config_dict):
    """
    Build an instance of this class from an optimum GPTQ config dict.

    Every key of `config_dict` is forwarded unchanged as a keyword argument to the constructor.
    """
    return cls(**config_dict)
@dataclass
class AwqConfig(GPTQConfig):
    """
    Configuration for models whose weights were quantized with AWQ (loaded through the `auto-awq` stack).

    It reuses [`GPTQConfig`] to store bits, group size, backend and format, and then switches
    `quant_method` to AWQ.

    Args:
        bits (`int`, *optional*, defaults to 4):
            The number of bits to quantize to.
        group_size (`int`, *optional*, defaults to 128):
            The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
        zero_point (`bool`, *optional*, defaults to `True`):
            Whether to use zero point quantization.
        backend (`AwqBackend`, *optional*, defaults to `AwqBackend.AUTO`):
            The kernel backend. `AwqBackend.LEGACY_AWQ` is replaced by `AwqBackend.AUTO` in the constructor.
            A backend equal to `"llm-awq"` at validation time selects the `AwqFormat.LLM_AWQ` format and the
            `AwqBackend.AUTO` backend.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
        kwargs:
            `format` selects the weight format (defaults to `AwqFormat.GEMM`). The legacy key `version`
            overrides it when it is not `None`. All other keys are forwarded to [`GPTQConfig`].

    Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models.
    """

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        zero_point: bool = True,
        backend: AwqBackend = AwqBackend.AUTO,
        modules_to_not_convert: list | None = None,
        **kwargs,
    ):
        weight_format = kwargs.pop("format", AwqFormat.GEMM)
        # Legacy configs store the format under `version`.
        if kwargs.get("version") is not None:
            weight_format = kwargs.pop("version").lower()
        # The legacy backend is served by automatic backend selection.
        resolved_backend = AwqBackend.AUTO if backend == AwqBackend.LEGACY_AWQ else backend

        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert
        super().__init__(
            bits=bits,
            group_size=group_size,
            backend=resolved_backend,
            format=weight_format,
            **kwargs,
        )
        # The GPTQ constructor set GPTQ; this config describes AWQ weights.
        self.quant_method = QuantizationMethod.AWQ

    def post_init(self):
        r"""
        Validate `format` and `backend`, translating the `"llm-awq"` backend alias first.

        This replaces the GPTQ checks entirely: bits, group size and dataset are not validated here.

        Raises:
            `ValueError`: if `format` is not an `AwqFormat` value or `backend` is not an `AwqBackend` value.
        """
        if self.backend == "llm-awq":
            self.format = AwqFormat.LLM_AWQ
            self.backend = AwqBackend.AUTO

        if self.format not in AwqFormat.__members__.values():
            allowed_formats = [member.value for member in AwqFormat]
            raise ValueError(f"Invalid format '{self.format}'. Must be one of: {allowed_formats}")
        if self.backend not in AwqBackend.__members__.values():
            allowed_backends = [member.value for member in AwqBackend]
            raise ValueError(f"Invalid backend '{self.backend}'. Must be one of: {allowed_backends}")

    def to_dict(self) -> dict[str, Any]:
        """Serialize the config, writing the format under the legacy `version` key instead of `checkpoint_format`."""
        serialized = super().to_dict()
        del serialized["checkpoint_format"]
        serialized["version"] = self.format
        return serialized
@dataclass
class AqlmConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about `aqlm` parameters.

    Args:
        in_group_size (`int`, *optional*, defaults to 8):
            The group size along the input dimension.
        out_group_size (`int`, *optional*, defaults to 1):
            The group size along the output dimension. It's recommended to always use 1.
        num_codebooks (`int`, *optional*, defaults to 1):
            Number of codebooks for the Additive Quantization procedure.
        nbits_per_codebook (`int`, *optional*, defaults to 16):
            Number of bits encoding a single codebook vector. Codebooks size is 2**nbits_per_codebook.
        linear_weights_not_to_quantize (`Optional[list[str]]`, *optional*):
            List of full paths of `nn.Linear` weight parameters that shall not be quantized.
            `None` is replaced by an empty list.
        kwargs (`dict[str, Any]`, *optional*):
            Additional parameters; accepted for compatibility and otherwise ignored.
    """

    def __init__(
        self,
        in_group_size: int = 8,
        out_group_size: int = 1,
        num_codebooks: int = 1,
        nbits_per_codebook: int = 16,
        linear_weights_not_to_quantize: list[str] | None = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.AQLM
        self.in_group_size = in_group_size
        self.out_group_size = out_group_size
        self.num_codebooks = num_codebooks
        self.nbits_per_codebook = nbits_per_codebook
        self.linear_weights_not_to_quantize = linear_weights_not_to_quantize
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.

        Raises:
            `TypeError`: if one of the integer hyper-parameters is not an `int`.
            `ValueError`: if `linear_weights_not_to_quantize` is neither `None` nor a list.
        """
        for name in ("in_group_size", "out_group_size", "num_codebooks", "nbits_per_codebook"):
            if not isinstance(getattr(self, name), int):
                # The check is for `int`; the previous messages wrongly said "float".
                raise TypeError(f"{name} must be an int")
        if self.linear_weights_not_to_quantize is None:
            self.linear_weights_not_to_quantize = []
        elif not isinstance(self.linear_weights_not_to_quantize, list):
            raise ValueError("linear_weights_not_to_quantize must be a list of strings")
@dataclass
class VptqLayerConfig(QuantizationConfigMixin):
    """
    This is used to explain vptq config params for each layer

    Args:
        enable_norm (`bool`, *optional*, defaults to `True`): to control if we have scale/bias for fp-weight
        enable_perm (`bool`, *optional*, defaults to `True`): to perm input_channel or not
        group_num (`int`, *optional*, defaults to `1`): how many single groups for vector-quantization
        group_size (`int`, *optional*, defaults to `-1`): depends on out-features
        in_features (`int`, *optional*, defaults to `-1`): input features of the layer (stored, not validated here)
        indices_as_float (`bool`, *optional*, defaults to `False`): for Finetuning
        is_indice_packed (`bool`, *optional*, defaults to `True`): should always be True
        num_centroids (`list`, *optional*, defaults to `[-1, -1]`): centroid numbers of clusters
        num_res_centroids (`list`, *optional*, defaults to `[-1, -1]`): ditto for residual
        out_features (`int`, *optional*, defaults to `-1`): output features of the layer (stored, not validated here)
        outlier_size (`int`, *optional*, defaults to `0`): outliers
        vector_lens (`list`, *optional*, defaults to `[-1, -1]`): centroid vector length in quantization
    """

    def __init__(
        self,
        enable_norm: bool = True,
        enable_perm: bool = True,
        group_num: int = 1,
        group_size: int = -1,
        in_features: int = -1,
        indices_as_float: bool = False,
        is_indice_packed: bool = True,
        num_centroids: list | None = None,
        num_res_centroids: list | None = None,
        out_features: int = -1,
        outlier_size: int = 0,
        vector_lens: list | None = None,
        **kwargs,
    ):
        self.enable_norm = enable_norm
        self.enable_perm = enable_perm
        self.group_num = group_num
        self.group_size = group_size
        self.in_features = in_features
        self.indices_as_float = indices_as_float
        self.is_indice_packed = is_indice_packed
        # Each instance gets its own `[-1, -1]` default list. A shared mutable default
        # would leak mutations into every later default-constructed config.
        self.num_centroids = [-1, -1] if num_centroids is None else num_centroids
        self.num_res_centroids = [-1, -1] if num_res_centroids is None else num_res_centroids
        self.out_features = out_features
        self.outlier_size = outlier_size
        self.vector_lens = [-1, -1] if vector_lens is None else vector_lens
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct

        Raises:
            `ValueError`: if `is_indice_packed` is `False`.
        """
        if self.is_indice_packed is False:
            raise ValueError("is_indice_packed should always be True")
@dataclass
class VptqConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about `vptq` parameters.

    Args:
        enable_proxy_error (`bool`, *optional*, defaults to `False`): calculate proxy error for each layer.
            Must currently stay `False`.
        config_for_layers (`Dict`, *optional*, defaults to `{}`): quantization params for each layer. Every
            value is validated by building a [`VptqLayerConfig`] from it.
        shared_layer_config (`Dict`, *optional*, defaults to `{}`): shared quantization params among layers
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
        kwargs (`dict[str, Any]`, *optional*):
            Additional parameters; accepted for compatibility and otherwise ignored.
    """

    def __init__(
        self,
        enable_proxy_error: bool = False,
        config_for_layers: dict[str, Any] | None = None,
        shared_layer_config: dict[str, Any] | None = None,
        modules_to_not_convert: list | None = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.VPTQ
        self.enable_proxy_error = enable_proxy_error
        # A fresh empty dict per instance; a shared `{}` default would leak mutations across configs.
        self.config_for_layers: dict[str, Any] = {} if config_for_layers is None else config_for_layers
        self.shared_layer_config: dict[str, Any] = {} if shared_layer_config is None else shared_layer_config
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct

        Raises:
            `ValueError`: if a per-layer config is invalid (as reported by `VptqLayerConfig`) or if
                `enable_proxy_error` is `True`.
        """
        for layer_param in self.config_for_layers.values():
            VptqLayerConfig(**layer_param)
        if self.enable_proxy_error is True:
            raise ValueError("enable_proxy_error should always be False until we support training")
@dataclass
class QuantoConfig(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using `quanto`.

    Args:
        weights (`str`, *optional*, defaults to `"int8"`):
            The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2")
        activations (`str`, *optional*):
            The target dtype for the activations after quantization. Supported values are (None,"int8","float8")
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
    """

    def __init__(
        self,
        weights="int8",
        activations=None,
        modules_to_not_convert: list | None = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.QUANTO
        self.weights = weights
        self.activations = activations
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct

        Raises:
            `ValueError`: if `weights` or `activations` is not one of the supported dtypes.
        """
        accepted_weights = ["float8", "int8", "int4", "int2"]
        accepted_activations = [None, "int8", "float8"]
        if self.weights not in accepted_weights:
            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
        if self.activations not in accepted_activations:
            # This error is about activations; the old message wrongly said "weights".
            raise ValueError(f"Only support activations in {accepted_activations} but found {self.activations}")
@dataclass
class EetqConfig(QuantizationConfigMixin):
    """
    Configuration for models loaded with `eetq` quantization.

    Args:
        weights (`str`, *optional*, defaults to `"int8"`):
            The target dtype for the weights. `"int8"` is the only accepted value.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
    """

    def __init__(
        self,
        weights: str = "int8",
        modules_to_not_convert: list | None = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.EETQ
        self.weights = weights
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Check that `weights` is a supported dtype.

        Raises:
            `ValueError`: if `weights` is anything other than `"int8"`.
        """
        allowed = ["int8"]
        if self.weights in allowed:
            return
        raise ValueError(f"Only support weights in {allowed} but found {self.weights}")
class CompressedTensorsConfig(QuantizationConfigMixin):
    """
    Wrapper that handles compressed-tensors quantization config options.

    It holds a `compressed_tensors.QuantizationConfig` (in `quantization_config`) and, optionally, a
    `compressed_tensors` sparsity config (in `sparsity_config`). Either attribute stays `None` when the
    corresponding inputs are empty.

    Args:
        config_groups (`dict[str, QuantizationScheme | list[str]]`, *optional*):
            Dictionary mapping group name to a quantization scheme definition.
        format (`str`, *optional*, defaults to `"dense"`):
            Format the model is represented as. Set `run_compressed` to `True` to execute the model in the
            compressed format if it is not `dense`.
        quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`):
            Status of the model in the quantization lifecycle, i.e. 'initialized', 'calibration', 'frozen'.
        kv_cache_scheme (`QuantizationArgs`, *optional*):
            Specifies quantization of the kv cache. If `None`, the kv cache is not quantized.
        global_compression_ratio (`float`, *optional*):
            0-1 float percentage of model compression.
        ignore (`list[str]`, *optional*):
            Layer names or types to not quantize; supports regex prefixed by 're:'.
        sparsity_config (`dict[str, Any]`, *optional*):
            Configuration for sparsity compression. Its `"format"` entry selects the registered sparsity config.
        quant_method (`str`, *optional*, defaults to `"compressed-tensors"`):
            Do not override; should be compressed-tensors.
        run_compressed (`bool`, *optional*, defaults to `True`):
            Alter submodules (usually linear) in order to emulate compressed model execution if `True`,
            otherwise use the default submodule.
        kwargs:
            Extra fields forwarded to `QuantizationConfig.model_validate` when a quantization config is built.
    """

    def __init__(
        self,
        config_groups: dict[str, Union["QuantizationScheme", list[str]]] | None = None,  # noqa: F821
        format: str = "dense",
        quantization_status: "QuantizationStatus" = "initialized",  # noqa: F821
        kv_cache_scheme: Optional["QuantizationArgs"] = None,  # noqa: F821
        global_compression_ratio: float | None = None,
        ignore: list[str] | None = None,
        sparsity_config: dict[str, Any] | None = None,
        quant_method: str = "compressed-tensors",
        run_compressed: bool = True,
        **kwargs,
    ):
        if not is_compressed_tensors_available():
            raise ImportError(
                "compressed_tensors is not installed and is required for compressed-tensors quantization. Please install it with `pip install compressed-tensors`."
            )
        from compressed_tensors.config import SparsityCompressionConfig
        from compressed_tensors.quantization import QuantizationConfig

        self.quantization_config = None
        self.sparsity_config = None
        self.run_compressed = run_compressed

        # Validating from a plain dict turns nested dicts into QuantizationScheme objects.
        if config_groups or kv_cache_scheme:
            quantization_fields = {
                "config_groups": config_groups,
                "quant_method": quant_method,
                "format": format,
                "quantization_status": quantization_status,
                "kv_cache_scheme": kv_cache_scheme,
                "global_compression_ratio": global_compression_ratio,
                "ignore": ignore,
                **kwargs,
            }
            self.quantization_config = QuantizationConfig.model_validate(quantization_fields)

        if sparsity_config:
            sparsity_format = sparsity_config.get("format")
            self.sparsity_config = SparsityCompressionConfig.load_from_registry(sparsity_format, **sparsity_config)

        self.quant_method = QuantizationMethod.COMPRESSED_TENSORS

    def post_init(self):
        """Disable `run_compressed` (with a warning) when the checkpoint cannot run compressed."""
        if not self.run_compressed:
            return
        if self.is_sparsification_compressed:
            logger.warning(
                "`run_compressed` is only supported for quantized_compressed models"
                " and not for sparsified models. Setting `run_compressed=False`"
            )
            self.run_compressed = False
        elif not self.is_quantization_compressed:
            logger.warning("`run_compressed` is only supported for compressed models. Setting `run_compressed=False`")
            self.run_compressed = False

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        """
        Instantiates a [`CompressedTensorsConfig`] from a Python dictionary of parameters.

        When `config_dict` has a `"quantization_config"` entry, that nested mapping is unwrapped. Only the outer
        `"sparsity_config"` entry is kept alongside it; other outer keys are dropped.

        Args:
            config_dict (`dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            return_unused_kwargs (`bool`,*optional*, defaults to `False`):
                Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
                `PreTrainedModel`.
            kwargs (`dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
        """
        if "quantization_config" in config_dict:
            outer_sparsity = config_dict.get("sparsity_config")
            config_dict = dict(sparsity_config=outer_sparsity, **config_dict["quantization_config"])
        return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def to_dict(self) -> dict[str, Any]:
        """
        Quantization config to be added to config.json.

        Returns:
            `dict[str, Any]`: the dumped quantization config (or just `quant_method` when there is none), plus a
            `"sparsity_config"` entry that is the dumped sparsity config or `{}`.
        """
        if self.quantization_config is None:
            serialized = {"quant_method": QuantizationMethod.COMPRESSED_TENSORS}
        else:
            serialized = self.quantization_config.model_dump()
        serialized["sparsity_config"] = {} if self.sparsity_config is None else self.sparsity_config.model_dump()
        return serialized

    def to_diff_dict(self) -> dict[str, Any]:
        """
        Serialize only the entries that differ from a default `CompressedTensorsConfig()`.

        Returns:
            `dict[str, Any]`: entries of `to_dict()` whose key is absent from, or whose value differs from, the
            default config's `to_dict()`.
        """
        current = self.to_dict()
        defaults = CompressedTensorsConfig().to_dict()
        return {key: value for key, value in current.items() if key not in defaults or value != defaults[key]}

    def get_loading_attributes(self):
        """Return the loading-time options: only `run_compressed`."""
        return {"run_compressed": self.run_compressed}

    @property
    def is_quantized(self):
        """`True` when a quantization config exists and its `config_groups` is non-empty."""
        qc = self.quantization_config
        return bool(qc) and bool(qc.config_groups)

    @property
    def is_quantization_compressed(self):
        """`True` when the model is quantized and its quantization status is `COMPRESSED`."""
        from compressed_tensors.quantization import QuantizationStatus

        if not self.is_quantized:
            return False
        qc = self.quantization_config
        return qc is not None and qc.quantization_status == QuantizationStatus.COMPRESSED

    @property
    def is_sparsification_compressed(self):
        """`True` when a sparsity config is present and its format is not dense."""
        from compressed_tensors.config import (
            CompressionFormat,
            SparsityCompressionConfig,
        )

        if not isinstance(self.sparsity_config, SparsityCompressionConfig):
            return False
        return self.sparsity_config.format != CompressionFormat.dense.value
@dataclass
class FbgemmFp8Config(QuantizationConfigMixin):
    """
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using fbgemm fp8 quantization.

    Args:
        activation_scale_ub (`float`, *optional*, defaults to 1200.0):
            The activation scale upper bound. This is used when quantizing the input activation.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
    """

    def __init__(
        self,
        activation_scale_ub: float = 1200.0,
        modules_to_not_convert: list | None = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.FBGEMM_FP8
        self.activation_scale_ub = activation_scale_ub
        self.modules_to_not_convert = modules_to_not_convert

    def get_loading_attributes(self):
        """Return a copy of the loading-time options (only `activation_scale_ub`, if set).

        Only the selected value is deep-copied, not the whole instance `__dict__`.
        """
        loading_attributes = ("activation_scale_ub",)
        return {
            name: copy.deepcopy(value)
            for name, value in self.__dict__.items()
            if name in loading_attributes
        }
@dataclass
class HiggsConfig(QuantizationConfigMixin):
    """
    HiggsConfig is a configuration class for quantization using the HIGGS method.

    Args:
        bits (`int`, *optional*, defaults to 4):
            Number of bits to use for quantization. Can be 2, 3 or 4.
        p (`int`, *optional*, defaults to 2):
            Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice.
        modules_to_not_convert (`list[str]`, *optional*):
            List of linear layers that should not be quantized. Stored as given; this config does not fill in a
            default list when it is `None`.
        hadamard_size (`int`, *optional*, defaults to 512):
            Hadamard size for the HIGGS method. Input dimension of matrices is padded to this value. Decreasing
            this below 512 will reduce the quality of the quantization. Must be divisible by `group_size`.
        group_size (`int`, *optional*, defaults to 256):
            Group size for the HIGGS method. Can be 64, 128 or 256. Decreasing it barely affects the performance.
        tune_metadata (`dict`, *optional*, defaults to `{}`):
            Module-wise metadata (gemm block shapes, GPU metadata, etc.) for saving the kernel tuning results.
            Is set automatically during tuning.
    """

    def __init__(
        self,
        bits: int = 4,
        p: int = 2,
        modules_to_not_convert: list[str] | None = None,
        hadamard_size: int = 512,
        group_size: int = 256,
        tune_metadata: dict[str, Any] | None = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.HIGGS
        self.bits = bits
        self.p = p
        self.modules_to_not_convert = modules_to_not_convert
        self.hadamard_size = hadamard_size
        self.group_size = group_size
        # A new empty dict per instance when no tuning results are supplied.
        self.tune_metadata = {} if tune_metadata is None else tune_metadata
        self.post_init()

    def post_init(self):
        r"""
        Validate the HIGGS hyper-parameters.

        Raises:
            `ValueError`: if `bits`, `p` or `group_size` is unsupported, or `hadamard_size` is not a multiple of
                `group_size`.
        """
        if self.bits not in (2, 3, 4):
            raise ValueError("bits must be 2, 3, or 4")
        if self.p not in (1, 2):
            raise ValueError("p must be 1 or 2. 2 is always better in practice")
        if self.group_size not in (64, 128, 256):
            raise ValueError("group_size must be 64, 128, or 256")
        if self.hadamard_size % self.group_size:
            raise ValueError("hadamard_size must be divisible by group_size")
@dataclass
class FPQuantConfig(QuantizationConfigMixin):
    """
    FPQuantConfig is a configuration class for quantization using the FPQuant method.

    Args:
        forward_dtype (`str`, *optional*, defaults to `"nvfp4"`):
            The dtype to use for the forward pass: `"nvfp4"` or `"mxfp4"`.
        forward_method (`str`, *optional*, defaults to `"abs_max"`):
            The scaling to use for the forward pass. Can be `"abs_max"` or `"quest"` (`"quest"` only with
            `"mxfp4"`). `"abs_max"` is better for PTQ, `"quest"` is better for QAT.
        backward_dtype (`str`, *optional*, defaults to `"bf16"`):
            The dtype to use for the backward pass: `"bf16"`, `"mxfp8"` or `"mxfp4"`. Values other than
            `"bf16"` require `forward_dtype="mxfp4"`.
        store_master_weights (`bool`, *optional*, defaults to `False`):
            Whether to store the master weights. Needed for QAT over layer weights.
        hadamard_group_size (`int`, *optional*):
            The group size for the hadamard transform before quantization. For `"quest"` it matches the MXFP4
            group size (32). If `None`, it is set to 16 for `"nvfp4"` and 32 otherwise. Allowed values are
            16/32/64/128 for `"nvfp4"` and 32/64/128 for `"mxfp4"`.
        pseudoquantization (`bool`, *optional*, defaults to `False`):
            Whether to use Triton-based pseudo-quantization. Is mandatory for non-Blackwell GPUs. Doesn't provide
            any speedup. For debugging purposes.
        transform_init (`str`, *optional*, defaults to `"hadamard"`):
            A method to initialize the pre-processing matrix with. Can be `"hadamard"`, `"identity"` or `"gsr"`.
        modules_to_not_convert (`list`, *optional*):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision. If `None`, it is set to `["lm_head"]`.
    """

    def __init__(
        self,
        forward_dtype: str = "nvfp4",
        forward_method: str = "abs_max",
        backward_dtype: str = "bf16",
        store_master_weights: bool = False,
        hadamard_group_size: int | None = None,
        pseudoquantization: bool = False,
        transform_init: str = "hadamard",
        modules_to_not_convert: list[str] | None = None,
        **kwargs,
    ):
        self.forward_dtype = forward_dtype
        self.forward_method = forward_method
        self.backward_dtype = backward_dtype
        self.store_master_weights = store_master_weights
        self.hadamard_group_size = hadamard_group_size
        self.pseudoquantization = pseudoquantization
        self.transform_init = transform_init
        self.modules_to_not_convert = modules_to_not_convert
        self.quant_method = QuantizationMethod.FPQUANT
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.

        Raises:
            `ValueError`: if `forward_dtype`, `forward_method`, `hadamard_group_size`, `backward_dtype` or
                `transform_init` is unsupported, or if the forward/backward dtypes are incompatible.
        """
        # Default once, up front. Per-dtype `None` checks further down would be unreachable.
        if self.hadamard_group_size is None:
            self.hadamard_group_size = 16 if self.forward_dtype == "nvfp4" else 32

        if self.forward_dtype == "mxfp4":
            if self.forward_method not in ["abs_max", "quest"]:
                raise ValueError("Only 'abs_max' and 'quest' are supported for forward_method for 'mxfp4'.")
            if self.hadamard_group_size not in [32, 64, 128]:
                raise ValueError("Only a `hadamard_group_size` of [32, 64, 128] is supported for 'mxfp4'.")
        elif self.forward_dtype == "nvfp4":
            if self.forward_method != "abs_max":
                raise ValueError("Only 'abs_max' is supported for forward_method for 'nvfp4'.")
            if self.hadamard_group_size not in [16, 32, 64, 128]:
                raise ValueError("Only a `hadamard_group_size` of [16, 32, 64, 128] is supported for 'nvfp4'.")
        else:
            raise ValueError("Only 'mxfp4' and 'nvfp4' are supported for forward_dtype for now.")

        if self.backward_dtype not in ["bf16", "mxfp8", "mxfp4"]:
            raise ValueError("Only 'bf16', 'mxfp8' and 'mxfp4' are supported for backward_dtype for now.")
        if self.backward_dtype != "bf16" and self.forward_dtype != "mxfp4":
            raise ValueError("Only 'mxfp4' forward is compatible with non-bf16 backwards for now.")
        if self.transform_init not in ["hadamard", "identity", "gsr"]:
            raise ValueError("Only 'hadamard', 'identity' and 'gsr' are supported for transform_init.")
        if self.modules_to_not_convert is None:
            self.modules_to_not_convert = ["lm_head"]
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
    """Config class for torchao quantization/sparsity techniques.

    Args:
        quant_type (`AOBaseConfig`):
            A torchao `AOBaseConfig` instance specifying the quantization type, e.g.
            `Int4WeightOnlyConfig(group_size=32)`, `Int8WeightOnlyConfig()`,
            `Int8DynamicActivationInt8WeightConfig()`, `Float8WeightOnlyConfig()`, etc.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
        include_input_output_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to include embedding in quantization or not, input embedding will be removed from
            the `modules_to_not_convert` list as well if this flag is set.
        untie_embedding_weights (`bool`, *optional*, defaults to `False`):
            Whether to untie the weights when we are quantizing input embedding weights that are tied
            to other weights.

    Example:

    ```python
    from torchao.quantization import Int4WeightOnlyConfig

    quantization_config = TorchAoConfig(Int4WeightOnlyConfig(group_size=32))
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config
    )
    ```
    """

    quant_method: QuantizationMethod
    quant_type: "AOBaseConfig"  # noqa: F821
    modules_to_not_convert: list | None
    include_input_output_embeddings: bool
    untie_embedding_weights: bool

    def __init__(
        self,
        quant_type: "AOBaseConfig",  # noqa: F821
        modules_to_not_convert: list | None = None,
        include_input_output_embeddings: bool = False,
        untie_embedding_weights: bool = False,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.TORCHAO
        self.quant_type = quant_type
        self.modules_to_not_convert = modules_to_not_convert
        self.include_input_output_embeddings = include_input_output_embeddings
        self.untie_embedding_weights = untie_embedding_weights
        self.post_init()

    def post_init(self):
        """Validate configuration and set defaults.

        Raises:
            ValueError: If torchao is not installed, or if `quant_type` is a legacy string.
            TypeError: If `quant_type` is not a torchao `AOBaseConfig` instance.
        """
        if not is_torchao_available():
            raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")

        if isinstance(self.quant_type, str):
            raise ValueError(
                f"String-based quantization type '{self.quant_type}' is no longer supported. "
                f"Please use the corresponding Config object directly, e.g. "
                f"TorchAoConfig(Int4WeightOnlyConfig(group_size=32)) instead of "
                f"TorchAoConfig('int4_weight_only', group_size=32)."
            )

        from torchao.quantization.quant_api import AOBaseConfig

        if not isinstance(self.quant_type, AOBaseConfig):
            raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type)}")

    def get_apply_tensor_subclass(self):
        """Return the quantization config to apply."""
        return self.quant_type

    def to_dict(self):
        """Convert configuration to a dictionary.

        `quant_type` is serialized with torchao's `config_to_dict` and wrapped under a single
        `"default"` key; `from_dict` expects this exact shape.
        """
        d = super().to_dict()

        from torchao.core.config import config_to_dict

        d["quant_type"] = {"default": config_to_dict(self.quant_type)}
        return d

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        """Create configuration from a dictionary.

        Args:
            config_dict (`dict`):
                Serialized configuration, as produced by `to_dict`. Its `quant_type` entry is expected to be
                `{"default": <torchao config dict>}`. Any other kind of value (e.g. a legacy string) is forwarded
                to the constructor unchanged, which validates it and raises an explanatory error if needed.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `True`, return a `(config, kwargs)` tuple instead of only the config.
            kwargs:
                Not applied to the configuration; returned as-is when `return_unused_kwargs=True`.

        Raises:
            ValueError: If the `quant_type` dictionary does not contain exactly the single key `"default"`.
        """
        config_dict = config_dict.copy()
        quant_type = config_dict.pop("quant_type")

        if isinstance(quant_type, dict):
            # Only the single-key {"default": ...} layout is produced by `to_dict`; it may be extended later.
            if len(quant_type) != 1 or "default" not in quant_type:
                raise ValueError("Expected only one key 'default' in quant_type dictionary")

            from torchao.core.config import config_from_dict

            quant_type = config_from_dict(quant_type["default"])
        # Non-dict values (e.g. legacy strings) go straight to __init__ so `post_init` can reject them
        # with its informative message instead of a generic shape error.

        config = cls(quant_type=quant_type, **config_dict)
        if return_unused_kwargs:
            return config, kwargs
        return config
@dataclass
class BitNetQuantConfig(QuantizationConfigMixin):
    """
    Configuration class for applying BitNet quantization.

    Args:
        modules_to_not_convert (`Optional[List]`, *optional*):
            Optionally, provides a list of full paths of `nn.Linear` weight parameters
            that shall not be quantized. Defaults to None.
        linear_class (`str`, *optional*, defaults to `"bitlinear"`):
            The type of linear class to use. Can be either `bitlinear` or `autobitlinear`.
        quantization_mode (`str`, *optional*, defaults to `"offline"`):
            The quantization mode to use. Can be either `online` or `offline`.
            In `online` mode, the weight quantization parameters are calculated dynamically
            during each forward pass (e.g., based on the current weight values). This can
            adapt to weight changes during training (Quantization-Aware Training - QAT).
            In `offline` mode, quantization parameters are pre-calculated *before* inference.
            These parameters are then fixed and loaded into the quantized model. This
            generally results in lower runtime overhead compared to online quantization.
        use_rms_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply RMSNorm on the activations before quantization. This matches the original BitNet paper's approach
            of normalizing activations before quantization/packing.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon value used in the RMSNorm layer for numerical stability.
        kwargs (`dict[str, Any]`, *optional*):
            Additional keyword arguments that may be used by specific quantization
            backends or future versions.
    """

    def __init__(
        self,
        modules_to_not_convert: list | None = None,
        linear_class: str = "bitlinear",
        quantization_mode: str = "offline",
        use_rms_norm: bool = False,
        rms_norm_eps: float | None = 1e-6,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.BITNET
        self.modules_to_not_convert = modules_to_not_convert
        self.linear_class = linear_class
        self.quantization_mode = quantization_mode
        self.use_rms_norm = use_rms_norm
        self.rms_norm_eps = rms_norm_eps
        # Validation lives in post_init (like the other configs) so it also covers later re-validation.
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct.

        Raises:
            ValueError: If `linear_class` is not `"bitlinear"`/`"autobitlinear"`, or if
                `quantization_mode` is not `"online"`/`"offline"`.
        """
        if self.linear_class not in ["bitlinear", "autobitlinear"]:
            raise ValueError(
                f"linear_class must be either 'bitlinear' or 'autobitlinear', but got {self.linear_class}"
            )
        if self.quantization_mode not in ["online", "offline"]:
            raise ValueError(
                f"quantization_mode must be either 'online' or 'offline', but got {self.quantization_mode}"
            )
@dataclass
class SpQRConfig(QuantizationConfigMixin):
    """
    Wrapper around the parameters of `spqr` quantization. See the original SpQR publication for background.

    Args:
        bits (`int`, *optional*, defaults to 3):
            Bit width used for the weights as well as first-order zero-points and scales.
            Only 3 is currently accepted.
        beta1 (`int`, *optional*, defaults to 16):
            SpQR tile width. Only 16 is currently accepted.
        beta2 (`int`, *optional*, defaults to 16):
            SpQR tile height. Only 16 is currently accepted.
        shapes (`Optional`, *optional*):
            Mapping holding the shape of each object. It is required because the exact parameter sizes
            cannot be derived from `bits`, `beta1` and `beta2` alone. `None` is stored as an empty dict.
        modules_to_not_convert (`Optional[list[str]]`, *optional*):
            Full paths of `nn.Linear` weight parameters that should stay unquantized. Defaults to None.
        kwargs (`dict[str, Any]`, *optional*):
            Extra keyword arguments; they are accepted and not stored.
    """

    def __init__(
        self,
        bits: int = 3,
        beta1: int = 16,
        beta2: int = 16,
        shapes: dict[str, int] | None = None,
        modules_to_not_convert: list[str] | None = None,
        **kwargs,
    ):
        self.shapes = {} if shapes is None else shapes
        self.quant_method = QuantizationMethod.SPQR
        self.bits = bits
        self.beta1 = beta1
        self.beta2 = beta2
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Check that the arguments are valid.

        All type checks run before any value check, then `shapes` is checked last.

        Raises:
            TypeError: If `bits`, `beta1` or `beta2` is not an int, or `shapes` is not a dict.
            ValueError: If `bits != 3`, `beta1 != 16` or `beta2 != 16`.
        """
        for attr_name in ("bits", "beta1", "beta2"):
            if not isinstance(getattr(self, attr_name), int):
                raise TypeError(f"{attr_name} must be an int")

        for attr_name, required in (("bits", 3), ("beta1", 16), ("beta2", 16)):
            if getattr(self, attr_name) != required:
                raise ValueError(f"SpQR currently only supports {attr_name} = {required}")

        if not isinstance(self.shapes, dict):
            raise TypeError("shapes must be a dict")
@dataclass
class FineGrainedFP8Config(QuantizationConfigMixin):
    """
    Configuration for fine-grained (block-wise) FP8 quantization, used mainly for DeepSeek models.

    Args:
        activation_scheme (`str`, *optional*, defaults to `"dynamic"`):
            Activation quantization scheme. `"dynamic"` and `"static"` are accepted; the value is
            lower-cased during validation.
        weight_block_size (`typing.tuple[int, int]`, *optional*, defaults to `(128, 128)`):
            Size of the weight blocks used for quantization. When not `None`, it must hold exactly
            two positive values.
        dequantize (`bool`, *optional*, defaults to `False`):
            Whether to dequantize the model during loading.
        modules_to_not_convert (`list`, *optional*):
            Module names that should not be converted during quantization.
    """

    def __init__(
        self,
        activation_scheme: str = "dynamic",
        weight_block_size: tuple[int, int] = (128, 128),
        dequantize: bool = False,
        modules_to_not_convert: list | None = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.FP8
        self.modules_to_not_convert = modules_to_not_convert
        self.activation_scheme = activation_scheme
        self.weight_block_size = weight_block_size
        self.dequantize = dequantize
        self.post_init()

    def post_init(self):
        r"""
        Normalize `activation_scheme` and validate the arguments.

        Raises:
            ValueError: If the activation scheme is unsupported, or if `weight_block_size`
                does not hold exactly two positive values.
        """
        scheme = self.activation_scheme.lower()
        self.activation_scheme = scheme
        if scheme not in ["dynamic", "static"]:
            raise ValueError(f"Activation scheme {self.activation_scheme} not supported")

        block_size = self.weight_block_size
        if block_size is None:
            # No block size configured: nothing further to validate.
            return
        if len(block_size) != 2:
            raise ValueError("weight_block_size must be a tuple of two integers")
        if block_size[0] <= 0 or block_size[1] <= 0:
            raise ValueError("weight_block_size must be a tuple of two positive integers")

    def get_loading_attributes(self):
        """Return the attributes forwarded to the loading logic."""
        loading_attributes = {"dequantize": self.dequantize}
        return loading_attributes
class QuarkConfig(QuantizationConfigMixin):
    """
    Configuration wrapper around an AMD Quark quantization config.

    All keyword arguments form the raw Quark quantization config (presumably the `quantization_config`
    section of a model's `config.json` — verify against callers). `kwargs["quant_method"]` is required.

    - If `quant_method` is `"awq"` or `"fp8"`, the config is parsed with Quark's custom-config parser and a
      default `JsonExporterConfig` is used.
    - Otherwise it is parsed with Quark's `Config.from_dict`. The `export` section, if present, builds the
      `JsonExporterConfig`.

    `legacy` is `True` when no `export` key is present.

    Raises:
        ImportError: If torch or Quark is not available.
        KeyError: If `quant_method` is missing from the keyword arguments.
    """

    def __init__(
        self,
        **kwargs,
    ):
        if not (is_torch_available() and is_quark_available()):
            raise ImportError(
                "Quark is not installed. Please refer to https://quark.docs.amd.com/latest/install.html."
            )

        from quark import __version__ as quark_version
        from quark.torch.export.config.config import JsonExporterConfig
        from quark.torch.export.main_export.quant_config_parser import QuantConfigParser
        from quark.torch.quantization.config.config import Config

        # This might be e.g. `"fp8"` or `"awq"`.
        self.custom_mode = kwargs["quant_method"]
        self.legacy = "export" not in kwargs

        if self.custom_mode in ["awq", "fp8"]:
            # Legacy (quark<1.0) or custom export.
            self.quant_config = QuantConfigParser.from_custom_config(kwargs, is_bias_quantized=False)
            self.json_export_config = JsonExporterConfig()
        else:
            self.quant_config = Config.from_dict(kwargs)

            if "export" in kwargs:
                # `**kwargs` only copies the top level; the nested `export` dict is still the caller's object.
                # Work on a copy so that dropping `min_kv_scale` does not mutate the caller's config.
                export_config = dict(kwargs["export"])
                # TODO: Remove this check once configuration version is handled natively by Quark.
                if "min_kv_scale" in export_config and version.parse(quark_version) < version.parse("0.8"):
                    min_kv_scale = export_config.pop("min_kv_scale")
                    logger.warning(
                        f"The parameter `min_kv_scale={min_kv_scale}` was found in the model config.json's `quantization_config.export` configuration, but this parameter is supported only for quark>=0.8. Ignoring this configuration parameter. Please update the `amd-quark` package."
                    )

                self.json_export_config = JsonExporterConfig(**export_config)
            else:
                # Legacy (quark<1.0) or custom export.
                self.json_export_config = JsonExporterConfig()

        self.quant_method = QuantizationMethod.QUARK
@dataclass
class Mxfp4Config(QuantizationConfigMixin):
    """
    Configuration holding the options available for a model loaded with mxfp4 quantization.

    Args:
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            Modules to leave unquantized, useful for models that explicitly require some modules
            to stay in their original precision.
        dequantize (`bool`, *optional*, defaults to `False`):
            Whether to dequantize the model to bf16 precision.
    """

    def __init__(
        self,
        modules_to_not_convert: list | None = None,
        dequantize: bool = False,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.MXFP4
        self.modules_to_not_convert = modules_to_not_convert
        self.dequantize = dequantize

    def get_loading_attributes(self):
        """Return the attributes forwarded to the loading logic."""
        loading_attributes = {"dequantize": self.dequantize}
        return loading_attributes

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this configuration to a Python dictionary.

        Only `quant_method` and `modules_to_not_convert` are emitted; `dequantize` is not serialized.

        Returns:
            `dict[str, Any]`: The serialized configuration.
        """
        serialized: dict[str, Any] = {}
        serialized["quant_method"] = self.quant_method
        serialized["modules_to_not_convert"] = self.modules_to_not_convert
        return serialized
class MetalConfig(QuantizationConfigMixin):
    """
    Configuration class for Metal affine quantization targeting Apple Silicon (MPS) devices.

    This quantization method uses the ``mlx-quantization-metal-kernels`` Metal kernels from the Hugging Face Hub
    to perform affine quantization (scales + qbiases) with configurable bit-width and group size.
    The quantized weights are packed into ``uint32`` tensors and the forward pass uses fused
    dequantization + matmul Metal kernels.

    Args:
        bits (`int`, *optional*, defaults to 4):
            Bit width of the quantized weights. Must be one of 2, 4 or 8.
        group_size (`int`, *optional*, defaults to 64):
            Number of weights sharing one scale/bias pair. Must be positive.
        modules_to_not_convert (`list`, *optional*):
            Modules to keep in their original precision.
        dequantize (`bool`, *optional*, defaults to `False`):
            Whether to dequantize the model while loading. Not included in `to_dict`.
    """

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 64,
        modules_to_not_convert: list | None = None,
        dequantize: bool = False,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.METAL
        self.bits = bits
        self.group_size = group_size
        self.modules_to_not_convert = modules_to_not_convert
        self.dequantize = dequantize
        self.post_init()

    def post_init(self):
        """
        Validate `bits` and `group_size`.

        Raises:
            ValueError: If `bits` is not 2, 4 or 8, or if `group_size` is not positive.
        """
        supported_bits = (2, 4, 8)
        if self.bits not in supported_bits:
            raise ValueError(f"Metal quantization only supports bits in {{2, 4, 8}}, got {self.bits}")
        if not self.group_size > 0:
            raise ValueError(f"group_size must be positive, got {self.group_size}")

    def get_loading_attributes(self):
        """Return the attributes forwarded to the loading logic."""
        loading_attributes = {"dequantize": self.dequantize}
        return loading_attributes

    def to_dict(self) -> dict[str, Any]:
        """Serialize the quantization parameters (without `dequantize`) to a dictionary."""
        serialized: dict[str, Any] = dict(
            quant_method=self.quant_method,
            bits=self.bits,
            group_size=self.group_size,
            modules_to_not_convert=self.modules_to_not_convert,
        )
        return serialized
@dataclass
class FourOverSixConfig(QuantizationConfigMixin):
    """
    This is a wrapper class containing all options for quantization with `fouroversix`. In brief,
    Four Over Six is a modification to NVFP4 quantization which adaptively scales the largest value
    in each block of 16 FP4 values to either 4 or 6. Selecting a scale of 6 uses the full range of
    FP4 values, but selecting a scale of 4 allows for a more uniform distribution of quantization
    error. Refer to the original publication for more details: https://arxiv.org/abs/2512.02010.

    Args:
        activation_scale_rule (`str`, *optional*):
            Scaling rule to use when selecting a scale for blocks in activation tensors. If not
            provided, `scale_rule` is used.
        dtype (`str`, *optional*, defaults to `"nvfp4"`):
            The data type to use for the layer's weights, activations, and tensors. Can be
            `"nvfp4"` or `"mxfp4"`.
        gradient_scale_rule (`str`, *optional*):
            Scaling rule to use when selecting a scale for blocks in gradient tensors. If not
            provided, `scale_rule` is used.
        keep_master_weights (`bool`, *optional*, defaults to `False`):
            Whether to keep the master weights. If `True`, high-precision weights are kept at all
            times and weights are quantized online in each forward pass. This is useful for
            quantized training.
        matmul_backend (`str`, *optional*):
            The backend to use for matrix multiplications. Can be `"cutlass"` or `"pytorch"`. If
            not provided, CUTLASS will be used if available and PyTorch will be used otherwise.
        output_dtype (`str`, *optional*, defaults to `"bfloat16"`):
            The data type to use for the output of the layer. Can be `"bfloat16"` or `"float16"`.
        quantize_backend (`str`, *optional*):
            The backend to use for quantization. Can be `"cuda"`, `"triton"`, or `"pytorch"`. If
            not provided, the fastest backend will be selected based on your environment, and based
            on the options supported by each backend. Typically, `"cuda"` will be used for
            inference, `"triton"` will be used for training, and `"pytorch"` will be used on
            non-CUDA devices.
        scale_rule (`str`, *optional*, defaults to `"mse"`):
            Rule to use when selecting block scales. Can be `"mse"`, `"mae"`, or `"abs_max"` for
            Four Over Six, `"static_6"` for default NVFP4 quantization, or `"static_4"` to scale
            all blocks to a maximum value of 4.
        weight_scale_2d (`bool`, *optional*, defaults to `False`):
            Whether to compute scale factors on weight tensors in 2D blocks. This should be done
            during training.
        weight_scale_rule (`str`, *optional*):
            Scaling rule to use when selecting a scale for blocks in weight tensors. If not
            provided, `scale_rule` is used.
        module_config_overrides (`dict[str, dict[str, Any]]`, *optional*):
            A dictionary of module-specific configuration overrides. Keys should be module names, and
            values should be dictionaries containing the quantization configuration for that module.
            This can be used to override the default configuration for specific modules.
        modules_to_not_convert (`list[str]`, *optional*, defaults to `['lm_head']`):
            The list of modules to exclude from quantization. By default, the `lm_head` is excluded.
            The given list is copied, so later changes to it do not affect this config.
    """

    def __init__(
        self,
        activation_scale_rule: str | None = None,
        dtype: str = "nvfp4",
        gradient_scale_rule: str | None = None,
        keep_master_weights: bool = False,
        matmul_backend: str | None = None,
        output_dtype: str | None = "bfloat16",
        quantize_backend: str | None = None,
        scale_rule: str = "mse",
        weight_scale_2d: bool = False,
        weight_scale_rule: str | None = None,
        module_config_overrides: dict[str, dict[str, Any]] | None = None,
        modules_to_not_convert: list[str] | None = ["lm_head"],  # noqa: B006 - copied below, never stored
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.FOUR_OVER_SIX
        self.activation_scale_rule = activation_scale_rule
        self.dtype = dtype
        self.gradient_scale_rule = gradient_scale_rule
        self.keep_master_weights = keep_master_weights
        self.matmul_backend = matmul_backend
        self.quantize_backend = quantize_backend
        self.output_dtype = output_dtype
        self.scale_rule = scale_rule
        self.weight_scale_2d = weight_scale_2d
        self.weight_scale_rule = weight_scale_rule
        self.module_config_overrides = module_config_overrides
        # Copy: storing the default list directly would make every default-constructed config share
        # (and potentially mutate) the same list object.
        self.modules_to_not_convert = (
            list(modules_to_not_convert) if modules_to_not_convert is not None else None
        )
class SinqConfig(QuantizationConfigMixin):
    """
    Quantization config for SINQ / A-SINQ.

    Pass this to:
        AutoModel.from_pretrained(..., quantization_config=SinqConfig(...))

    Args:
        nbits (`int`, *optional*, defaults to 4):
            Quantization bits for weights. Converted with `int()` during validation.
        group_size (`int`, *optional*, defaults to 64):
            Group size used in SINQ weight quantization. Should be a multiple of 8; other values only
            trigger a warning here. Converted with `int()` during validation.
        tiling_mode (`str`, *optional*, defaults to `"1D"`):
            Tiling mode for SINQ (typically "1D"; "2D" if supported in your backend).
        method (`str`, *optional*, defaults to `"sinq"`):
            "sinq" – calibration-free weight-only SINQ
            "asinq" – A-SINQ (activation-aware), not supported in Hugging Face. Please refer to the official SINQ repository.
            Compared case-insensitively (lower-cased during validation).
        modules_to_not_convert (`list[str]`, *optional*):
            List of module names/prefixes to keep in full precision.
        **kwargs:
            Extra user arguments (kept in `_extra_kwargs` for round-tripping).
    """

    def __init__(
        self,
        nbits: int = 4,
        group_size: int = 64,
        tiling_mode: str = "1D",
        method: str = "sinq",  # "sinq" | "asinq"
        modules_to_not_convert: list[str] | None = None,
        **kwargs: Any,
    ):
        self.quant_method = QuantizationMethod.SINQ
        self.nbits = nbits
        self.group_size = group_size
        self.tiling_mode = tiling_mode
        self.method = method
        self.modules_to_not_convert = modules_to_not_convert
        self._extra_kwargs: dict[str, Any] = dict(kwargs)
        self.post_init()

    def post_init(self):
        """
        Normalize the field types and validate them.

        Raises:
            TypeError: If `nbits` or `group_size` cannot be converted to an int.
            ValueError: If `method` is neither "sinq" nor "asinq".
        """
        # Wrap conversion errors so users get the intended, descriptive message instead of a raw
        # `int()` failure (previous post-conversion isinstance checks could never fail).
        try:
            self.nbits = int(self.nbits)
        except (TypeError, ValueError, OverflowError) as err:
            raise TypeError("`nbits` must be convertible to an int") from err
        try:
            self.group_size = int(self.group_size)
        except (TypeError, ValueError, OverflowError) as err:
            raise TypeError("`group_size` must be convertible to an int") from err
        self.tiling_mode = str(self.tiling_mode)
        self.method = str(self.method).lower()

        if self.method not in {"sinq", "asinq"}:
            raise ValueError(f"`method` must be either 'sinq' or 'asinq', got {self.method}")

        # group_size is guaranteed to be an int at this point.
        if self.group_size % 8 != 0:
            logger.warning(
                f"SINQ: group_size={self.group_size} is not a multiple of 8; this may be rejected by the backend."
            )
|