quantizer_finegrained_fp8.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. from typing import TYPE_CHECKING
  2. from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
  3. from .base import HfQuantizer
  4. from .quantizers_utils import get_module_from_name
  5. if is_torch_available():
  6. import torch
  7. if TYPE_CHECKING:
  8. from ..modeling_utils import PreTrainedModel
  9. from ..utils.quantization_config import FineGrainedFP8Config
  10. logger = logging.get_logger(__name__)
  class FineGrainedFP8HfQuantizer(HfQuantizer):
      """
      Quantizer for fine-grained FP8 models, covering both standard and MoE layers.

      The quantizer checks that a suitable accelerator is present, swaps eligible
      modules for their FP8 counterparts before the weights are loaded, and tells
      the loader which parameters must be quantized on the fly. When the hardware
      cannot run FP8 kernels, `validate_environment` switches the quantization
      config to dequantize mode instead (see the warnings it emits).

      NOTE(review): the original docstring mentions e4m3fn formats selected per
      platform; that selection is not visible in this class -- confirm in the
      `finegrained_fp8` integration module.
      """

      # No calibration pass is needed for this quantization method.
      requires_calibration = False
      quantization_config: "FineGrainedFP8Config"

      def __init__(self, quantization_config, **kwargs):
          """Forward the quantization config and extra kwargs to `HfQuantizer`."""
          super().__init__(quantization_config, **kwargs)

      def validate_environment(self, *args, **kwargs):
          """
          Check that the runtime can load (or quantize) an FP8 model.

          May set `self.quantization_config.dequantize = True` as a fallback when
          no GPU/XPU is available for a pre-quantized model, or when the CUDA
          device has a compute capability below 8.9.

          Args:
              **kwargs: only `device_map` is read; it may be `None`, a dict, or
                  any other value (other values are accepted without checks).

          Raises:
              ImportError: if accelerate is not installed.
              RuntimeError: if no GPU/XPU is available and the model has to be
                  quantized on the fly.
              ValueError: if the model has to be quantized on the fly and a
                  multi-entry `device_map` places something on cpu or disk.
          """
          if not is_accelerate_available():
              raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")

          # Dequantization was explicitly requested: no FP8 hardware is needed.
          if self.quantization_config.dequantize:
              return

          if not torch.cuda.is_available() and not is_torch_xpu_available():
              if not self.pre_quantized:
                  raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
              logger.warning_once(
                  "Using FP8 quantized models requires a GPU or XPU, we will default to dequantizing the model to bf16 since no GPU or XPU is available"
              )
              self.quantization_config.dequantize = True
              return

          if torch.cuda.is_available():
              major, minor = torch.cuda.get_device_capability()
              if (major < 8) or (major == 8 and minor < 9):
                  logger.warning_once(
                      "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
                      f", actual = `{major}.{minor}`. We will default to dequantizing the model to bf16. Feel free "
                      f"to use a different quantization method like bitsandbytes or torchao"
                  )
                  self.quantization_config.dequantize = True
                  return

          device_map = kwargs.get("device_map")
          if device_map is None:
              logger.warning_once(
                  "You have loaded an FP8 model on CPU and have a CUDA or XPU device available, make sure to set "
                  "your model on a GPU or XPU device in order to run your model. To remove this warning, "
                  "pass device_map = 'cuda' or 'xpu'. "
              )
          elif isinstance(device_map, dict):
              # The cpu/disk test must be grouped: without the parentheses `and`
              # binds tighter than `or`, so any "disk" entry raised even for
              # pre-quantized checkpoints, contradicting the error message.
              targets = device_map.values()
              has_offload = "cpu" in targets or "disk" in targets
              if not self.pre_quantized and len(device_map) > 1 and has_offload:
                  raise ValueError(
                      "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device. "
                      "This is not supported when the model is quantized on the fly. "
                      "Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
                  )

      def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
          """
          Tell whether `param_name` must be quantized while loading.

          Returns `True` only for non-bias tensors that belong to an `FP8Linear`
          or `FP8Experts` module when the checkpoint is not already quantized.
          """
          from ..integrations.finegrained_fp8 import FP8Experts, FP8Linear

          module, tensor_name = get_module_from_name(model, param_name)
          return (
              isinstance(module, (FP8Linear, FP8Experts))
              and not (self.pre_quantized or tensor_name == "bias")
          )

      def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
          """Return the element size (in bytes) for `param_name`."""
          if self.param_needs_quantization(model, param_name):
              # 8 bit: this is needed because when `pre_quantized` is False we do
              # not set the dtype of the FP8Linear, in order to load the weights
              # correctly.
              return 1
          return super().param_element_size(model, param_name, param)

      def _process_model_before_weight_loading(
          self,
          model: "PreTrainedModel",
          **kwargs,
      ):
          """
          Resolve the modules to keep unconverted, then run `replace_with_fp8_linear`.

          The resolved list is stored on `self.modules_to_not_convert`. The value
          returned by `replace_with_fp8_linear` is not used.
          """
          from ..integrations.finegrained_fp8 import replace_with_fp8_linear

          self.modules_to_not_convert = self.get_modules_to_not_convert(
              model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
          )

          replace_with_fp8_linear(
              model,
              modules_to_not_convert=self.modules_to_not_convert,
              quantization_config=self.quantization_config,
              pre_quantized=self.pre_quantized,
          )

      def update_tp_plan(self, config):
          """
          Install a tensor-parallel plan on Qwen3 configs and return the config.

          When the config class name contains "Qwen3", `config.base_model_tp_plan`
          is overwritten with a plan that lists each projection's `weight` and
          `weight_scale_inv` under the same sharding style. Other configs are
          returned unchanged.
          """
          if "Qwen3" in config.__class__.__name__:
              text_plan = {
                  "layers.*.self_attn.q_proj.weight": "colwise",
                  "layers.*.self_attn.q_proj.weight_scale_inv": "colwise",
                  "layers.*.self_attn.k_proj.weight": "colwise",
                  "layers.*.self_attn.k_proj.weight_scale_inv": "colwise",
                  "layers.*.self_attn.v_proj.weight": "colwise",
                  "layers.*.self_attn.v_proj.weight_scale_inv": "colwise",
                  "layers.*.self_attn.o_proj.weight": "rowwise",
                  "layers.*.self_attn.o_proj.weight_scale_inv": "rowwise",
                  "layers.*.mlp.gate_proj.weight": "colwise",
                  "layers.*.mlp.gate_proj.weight_scale_inv": "colwise",
                  "layers.*.mlp.up_proj.weight": "colwise",
                  "layers.*.mlp.up_proj.weight_scale_inv": "colwise",
                  "layers.*.mlp.down_proj.weight": "rowwise",
                  "layers.*.mlp.down_proj.weight_scale_inv": "rowwise",
              }
              config.base_model_tp_plan = text_plan

          return config

      def is_serializable(self):
          """Report that models handled by this quantizer can be serialized."""
          return True

      @property
      def is_trainable(self) -> bool:
          """Training is reported as unsupported for this quantizer."""
          return False

      @property
      def is_compileable(self) -> bool:
          """Compilation is reported as supported for this quantizer."""
          return True

      def get_quantize_ops(self):
          """Return the `Fp8Quantize` operation bound to this quantizer."""
          from ..integrations.finegrained_fp8 import Fp8Quantize

          return Fp8Quantize(self)

      def get_weight_conversions(self):
          """
          Return the weight converters needed to dequantize a pre-quantized checkpoint.

          A single `WeightConverter` running `Fp8Dequantize` is returned when the
          model is pre-quantized and dequantization is enabled; otherwise the
          list is empty.
          """
          from ..core_model_loading import WeightConverter
          from ..integrations.finegrained_fp8 import Fp8Dequantize

          if self.pre_quantized and self.quantization_config.dequantize:
              return [
                  # either use the dollar sign, or permute the source patterns to start matching against the scales first
                  # We also collect the activation scales, they will not be used
                  WeightConverter(
                      source_patterns=["weight$", "weight_scale_inv", "activation_scale"],
                      target_patterns="weight",
                      operations=[Fp8Dequantize(self)],
                  )
              ]
          return []