quantizer_fp_quant.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


# Imported for static type checking only (TYPE_CHECKING is False at runtime);
# the names are used in string annotations below.
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel
    from ..utils.quantization_config import FPQuantConfig

from ..utils import is_fp_quant_available, is_qutlass_available, is_torch_available, is_torch_xpu_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


# torch is optional at import time; the name is only bound when torch is installed.
if is_torch_available():
    import torch

logger = logging.get_logger(__name__)
  25. class FPQuantHfQuantizer(HfQuantizer):
  26. """
  27. Quantizer for the FP-Quant method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
  28. """
  29. requires_calibration = False
  30. is_qat_trainable = True
  31. quantization_config: "FPQuantConfig"
  32. def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
  33. super().__init__(quantization_config, **kwargs)
  34. def validate_environment(self, device_map, **kwargs):
  35. if not torch.cuda.is_available() and not is_torch_xpu_available():
  36. raise NotImplementedError(
  37. "FPQuant quantization is only supported on GPU or Intel XPU. Please use a different quantizer."
  38. )
  39. if not is_qutlass_available() and not self.quantization_config.pseudoquantization:
  40. raise ImportError(
  41. "Using `fp_quant` with real quantization requires a **Blackwell GPU** and qutlass: `git clone https://github.com/IST-DASLab/qutlass.git && cd qutlass && pip install --no-build-isolation .`. You can use `FPQuantConfig(pseudoquantization=True, ...)` to use Triton-based pseudo-quantization. It doesn't provide any speedups but emulates the quantization behavior of the real quantization."
  42. )
  43. if (
  44. self.quantization_config.pseudoquantization
  45. and self.quantization_config.forward_dtype == "nvfp4"
  46. and torch.cuda.is_available()
  47. and torch.cuda.get_device_capability()[0] < 9
  48. ):
  49. raise ValueError(
  50. "NVFP4 pseudoquantization requires a GPU with compute capability >= 9.0 (Hopper or newer) "
  51. "because the Triton kernel uses the `fp8e4nv` type. Please use `forward_dtype='mxfp4'` instead, "
  52. "or use a GPU with compute capability >= 9.0."
  53. )
  54. if self.quantization_config.pseudoquantization:
  55. logger.warning(
  56. "Using pseudo-quantization for FP-Quant. This doesn't provide any speedups but emulates the quantization behavior of the real quantization."
  57. )
  58. if not is_fp_quant_available():
  59. raise ImportError("Using `fp_quant` quantization requires fp_quant: `pip install fp_quant`")
  60. if device_map is None and not self.quantization_config.pseudoquantization:
  61. raise ValueError(
  62. "You are attempting to load a FPQuant model without setting device_map."
  63. " Please set device_map comprised of 'cuda' devices."
  64. )
  65. elif isinstance(device_map, dict):
  66. if (
  67. not self.quantization_config.pseudoquantization
  68. and len(device_map) > 1
  69. and "cpu" in device_map.values()
  70. or "disk" in device_map.values()
  71. ):
  72. raise ValueError(
  73. "You are attempting to load a FPQuant model with a device_map that contains a CPU or disk device."
  74. " This is not supported. Please remove the CPU or disk device from the device_map."
  75. )
  76. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  77. if dtype != torch.bfloat16:
  78. logger.warning_once(
  79. f"Setting dtype to {dtype}, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
  80. )
  81. dtype = torch.bfloat16
  82. return dtype
  83. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  84. from fp_quant import FPQuantLinear
  85. module, tensor_name = get_module_from_name(model, param_name)
  86. if isinstance(module, FPQuantLinear) and tensor_name in ["weight", "qweight", "dqweight"]:
  87. # Only quantize weights of FPQuantLinear modules that are not already quantized
  88. return True
  89. else:
  90. return False
  91. def _process_model_before_weight_loading(
  92. self,
  93. model: "PreTrainedModel",
  94. **kwargs,
  95. ):
  96. from fp_quant import replace_with_fp_quant_linear
  97. from ..integrations.fp_quant import adapt_fp_quant_config
  98. replace_with_fp_quant_linear(
  99. model,
  100. fp_quant_linear_config=adapt_fp_quant_config(self.quantization_config),
  101. )
  102. @property
  103. def is_trainable(self, model: Optional["PreTrainedModel"] = None):
  104. trainable = self.quantization_config.store_master_weights
  105. if not trainable:
  106. logger.warning(
  107. "You are attempting to train a model with FPQuant quantization. This is only supported when `store_master_weights=True`. Please set `store_master_weights=True` to train the model."
  108. )
  109. return trainable
  110. def is_serializable(self):
  111. return True
  112. def get_quantize_ops(self):
  113. from ..integrations.fp_quant import FpQuantQuantize
  114. return FpQuantQuantize(self)
  115. def get_weight_conversions(self):
  116. from ..core_model_loading import WeightConverter
  117. from ..integrations.fp_quant import FpQuantDeserialize
  118. if self.pre_quantized:
  119. if self.quantization_config.pseudoquantization:
  120. return [
  121. WeightConverter(
  122. source_patterns=[".dqweight"],
  123. target_patterns=".dqweight",
  124. operations=[FpQuantDeserialize(self)],
  125. ),
  126. ]
  127. else:
  128. return [
  129. WeightConverter(
  130. source_patterns=[".qweight"],
  131. target_patterns=".qweight",
  132. operations=[FpQuantDeserialize(self)],
  133. ),
  134. ]
  135. return []