quantizer_fouroversix.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from typing import TYPE_CHECKING
  2. from ..utils.import_utils import is_fouroversix_available
  3. from .base import HfQuantizer
  4. from .quantizers_utils import get_module_from_name
  5. if TYPE_CHECKING:
  6. from ..modeling_utils import PreTrainedModel
  7. from ..utils.quantization_config import FourOverSixConfig
  8. from ..utils import (
  9. is_torch_available,
  10. )
  11. if is_torch_available():
  12. import torch
  13. class FourOverSixHfQuantizer(HfQuantizer):
  14. """
  15. FP4 quantization with fouroversix.
  16. """
  17. requires_calibration = False
  18. quantization_config: "FourOverSixConfig"
  19. def __init__(self, quantization_config, **kwargs):
  20. super().__init__(quantization_config, **kwargs)
  21. def validate_environment(self, *args, **kwargs):
  22. if not is_fouroversix_available():
  23. raise ImportError(
  24. "Using `fouroversix` requires fouroversix: `pip install fouroversix --no-build-isolation`"
  25. )
  26. def param_element_size(
  27. self,
  28. model: "PreTrainedModel",
  29. param_name: str,
  30. param: "torch.Tensor",
  31. ) -> float:
  32. from fouroversix import QuantizedModule
  33. module, tensor_name = get_module_from_name(model, param_name)
  34. if QuantizedModule.is_quantized_module_type(type(module)):
  35. return module.get_element_size(tensor_name)
  36. return super().param_element_size(model, param_name, param)
  37. def param_needs_quantization(
  38. self,
  39. model: "PreTrainedModel",
  40. param_name: str,
  41. **kwargs,
  42. ) -> bool:
  43. from fouroversix import QuantizedModule
  44. module, tensor_name = get_module_from_name(model, param_name)
  45. return QuantizedModule.is_quantized_module_type(type(module)) and tensor_name in module.parameters_to_quantize
  46. def _process_model_before_weight_loading(
  47. self,
  48. model: "PreTrainedModel",
  49. device_map,
  50. **kwargs,
  51. ):
  52. from fouroversix import QuantizedModule, quantize_model
  53. from ..integrations.fouroversix import adapt_fouroversix_config
  54. quantize_model(
  55. model,
  56. adapt_fouroversix_config(self.quantization_config),
  57. )
  58. # If the model has already been quantized, we need to delete the weight tensor here so that
  59. # it's not expected when parameters are loaded from the checkpoint.
  60. if self.pre_quantized and not self.quantization_config.keep_master_weights:
  61. for _, module in model.named_modules():
  62. if QuantizedModule.is_quantized_module_type(type(module)):
  63. for parameter_name in module.parameters_to_quantize:
  64. delattr(module, parameter_name)
  65. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  66. return model
  67. def is_serializable(self):
  68. return True
  69. @property
  70. def is_trainable(self) -> bool:
  71. return self.quantization_config.keep_master_weights
  72. def get_quantize_ops(self):
  73. from ..integrations.fouroversix import FourOverSixQuantize
  74. return FourOverSixQuantize(self)
  75. def get_weight_conversions(self):
  76. """
  77. Return weight conversions for loading pre-quantized checkpoints of
  78. other pre-quantized models (not fouroversix models). After first use,
  79. the pre_quantized_model_config_type attribute is set to None to ensure
  80. subsequent calls (e.g., during save_pretrained) return an empty list
  81. since, by then, the model will be saved with our framework's format
  82. so weight conversions are no longer needed.
  83. """
  84. from fouroversix import WeightConversions
  85. # pre_quantized_model_config_type is only set if we are loading a
  86. # pre-quantized model so it is not guaranteed to exist.
  87. if hasattr(self.quantization_config, "pre_quantized_model_config_type"):
  88. model_config_type = self.quantization_config.pre_quantized_model_config_type
  89. weight_conversions = WeightConversions.get_weight_conversions(
  90. model_config_type,
  91. )
  92. return weight_conversions
  93. return []