| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import TYPE_CHECKING
- from .base import HfQuantizer
- from .quantizers_utils import get_module_from_name
- if TYPE_CHECKING:
- from ..modeling_utils import PreTrainedModel
- from ..utils import (
- is_accelerate_available,
- is_optimum_quanto_available,
- is_torch_available,
- logging,
- )
- from ..utils.quantization_config import QuantoConfig
- if is_torch_available():
- import torch
- logger = logging.get_logger(__name__)
class QuantoHfQuantizer(HfQuantizer):
    """
    Quantizer for the quanto library.

    Only weight quantization is handled here: `validate_environment` rejects
    configurations that request activation quantization.
    """

    requires_calibration = False

    quantization_config: "QuantoConfig"

    def __init__(self, quantization_config: QuantoConfig, **kwargs):
        """Store the config and precompute the per-element size of quantized weights."""
        super().__init__(quantization_config, **kwargs)
        # Size in bytes of a single quantized element for each supported weights dtype.
        map_to_param_size = {
            "int8": 1,
            "float8": 1,
            "int4": 0.5,
            "int2": 0.25,
        }
        # `None` when the configured weights dtype is not listed above; in that case
        # `param_element_size` falls back to the parent implementation.
        self.quantized_param_size = map_to_param_size.get(self.quantization_config.weights)

    def validate_environment(self, *args, **kwargs):
        """
        Check that the required libraries are installed and that the requested setup is supported.

        Reads `device_map` from `kwargs`; other arguments are ignored.

        Raises:
            ImportError: if optimum-quanto or accelerate is not available.
            ValueError: if a dict `device_map` is rejected by the CPU/disk check below, or if the
                quantization config requests activation quantization.
        """
        if not is_optimum_quanto_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
            )
        if not is_accelerate_available():
            raise ImportError(
                "Loading an optimum-quanto quantized model requires accelerate library (`pip install accelerate`)"
            )

        device_map = kwargs.get("device_map")
        if isinstance(device_map, dict):
            placements = device_map.values()
            # `and` binds tighter than `or`; the parentheses only spell out the evaluation
            # order the expression already had: a multi-entry map that uses "cpu" is rejected,
            # and any map that uses "disk" is rejected.
            # NOTE(review): confirm that rejecting a single-entry all-disk map while allowing a
            # single-entry all-cpu map is the intended behavior.
            if (len(device_map) > 1 and "cpu" in placements) or "disk" in placements:
                raise ValueError(
                    "You are attempting to load a model with a device_map that contains a CPU or disk device. "
                    "This is not supported with quanto when the model is quantized on the fly. "
                    "Please remove the CPU or disk device from the device_map."
                )

        if self.quantization_config.activations is not None:
            raise ValueError(
                "We don't support quantizing the activations with transformers library. "
                "Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
            )

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """
        Tell whether the parameter `param_name` of `model` still has to be quantized.

        Returns `True` only when the owning module is a quanto `QModuleMixin`, the tensor name
        contains "weight", and the module is not frozen yet.
        """
        # Imported lazily: optimum-quanto is an optional dependency.
        from optimum.quanto import QModuleMixin

        module, tensor_name = get_module_from_name(model, param_name)
        # We only quantize the weights and the bias is not quantized.
        if not (isinstance(module, QModuleMixin) and "weight" in tensor_name):
            return False
        # if the weights are quantized, don't need to recreate it again with `create_quantized_param`
        return not module.frozen

    def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
        """
        Return a copy of `max_memory` with every value scaled down to 90%.

        Each value is multiplied by a float, so the values must be numeric: a string size would
        raise `TypeError` here, and integer inputs come back as floats.
        """
        return {key: val * 0.90 for key, val in max_memory.items()}

    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
        """Return the element size (in bytes) for `param_name`."""
        if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
            return self.quantized_param_size

        return super().param_element_size(model, param_name, param)

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """
        Prepare `model` for quantized loading.

        Resolves `self.modules_to_not_convert` from the quantization config and the model's
        `_keep_in_fp32_modules`, then calls `replace_with_quanto_layers` on `model`. Nothing is
        returned, so callers only observe changes made to `model` in place.
        """
        from ..integrations import replace_with_quanto_layers

        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
        )

        # The return value is deliberately not bound: this method returns nothing, so a local
        # rebinding of `model` would be dead code.
        replace_with_quanto_layers(
            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
        )

    @property
    def is_trainable(self) -> bool:
        """Always `True` for this quantizer."""
        return True

    def is_serializable(self):
        """Always `False` for this quantizer."""
        return False

    def get_quantize_ops(self):
        """Return a `QuantoQuantize` operation bound to this quantizer."""
        from ..integrations.quanto import QuantoQuantize

        return QuantoQuantize(self)
|