quantizer_quanto.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from .base import HfQuantizer
  16. from .quantizers_utils import get_module_from_name
  17. if TYPE_CHECKING:
  18. from ..modeling_utils import PreTrainedModel
  19. from ..utils import (
  20. is_accelerate_available,
  21. is_optimum_quanto_available,
  22. is_torch_available,
  23. logging,
  24. )
  25. from ..utils.quantization_config import QuantoConfig
  26. if is_torch_available():
  27. import torch
  28. logger = logging.get_logger(__name__)
  29. class QuantoHfQuantizer(HfQuantizer):
  30. """
  31. Quantizer for the quanto library
  32. """
  33. requires_calibration = False
  34. quantization_config: "QuantoConfig"
  35. def __init__(self, quantization_config: QuantoConfig, **kwargs):
  36. super().__init__(quantization_config, **kwargs)
  37. map_to_param_size = {
  38. "int8": 1,
  39. "float8": 1,
  40. "int4": 0.5,
  41. "int2": 0.25,
  42. }
  43. self.quantized_param_size = map_to_param_size.get(self.quantization_config.weights, None)
  44. def validate_environment(self, *args, **kwargs):
  45. if not is_optimum_quanto_available():
  46. raise ImportError(
  47. "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
  48. )
  49. if not is_accelerate_available():
  50. raise ImportError(
  51. "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
  52. )
  53. device_map = kwargs.get("device_map")
  54. if isinstance(device_map, dict):
  55. if len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values():
  56. raise ValueError(
  57. "You are attempting to load an model with a device_map that contains a CPU or disk device."
  58. "This is not supported with quanto when the model is quantized on the fly. "
  59. "Please remove the CPU or disk device from the device_map."
  60. )
  61. if self.quantization_config.activations is not None:
  62. raise ValueError(
  63. "We don't support quantizing the activations with transformers library."
  64. "Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
  65. )
  66. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  67. from optimum.quanto import QModuleMixin
  68. module, tensor_name = get_module_from_name(model, param_name)
  69. # We only quantize the weights and the bias is not quantized.
  70. if isinstance(module, QModuleMixin) and "weight" in tensor_name:
  71. # if the weights are quantized, don't need to recreate it again with `create_quantized_param`
  72. return not module.frozen
  73. else:
  74. return False
  75. def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
  76. max_memory = {key: val * 0.90 for key, val in max_memory.items()}
  77. return max_memory
  78. def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
  79. "Return the element size (in bytes) for `param_name`."
  80. if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
  81. return self.quantized_param_size
  82. return super().param_element_size(model, param_name, param)
  83. def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
  84. from ..integrations import replace_with_quanto_layers
  85. self.modules_to_not_convert = self.get_modules_to_not_convert(
  86. model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
  87. )
  88. model = replace_with_quanto_layers(
  89. model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
  90. )
  91. @property
  92. def is_trainable(self) -> bool:
  93. return True
  94. def is_serializable(self):
  95. return False
  96. def get_quantize_ops(self):
  97. from ..integrations.quanto import QuantoQuantize
  98. return QuantoQuantize(self)