quantizer_awq.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib.metadata
  15. from typing import TYPE_CHECKING
  16. from packaging import version
  17. from .base import HfQuantizer
  18. if TYPE_CHECKING:
  19. from ..modeling_utils import PreTrainedModel
  20. from ..utils.quantization_config import AwqConfig
  21. from ..utils import is_accelerate_available, is_gptqmodel_available, is_torch_available, logging
  22. from ..utils.quantization_config import AwqBackend
  23. if is_torch_available():
  24. import torch
  25. logger = logging.get_logger(__name__)
  26. class AwqQuantizer(HfQuantizer):
  27. """
  28. 4-bit quantization for Activation-aware Weight Quantization(AWQ) (https://huggingface.co/papers/2306.00978)
  29. """
  30. # AWQ requires data calibration - we support only inference
  31. requires_calibration = True
  32. quantization_config: "AwqConfig"
  33. def __init__(self, quantization_config, **kwargs):
  34. super().__init__(quantization_config, **kwargs)
  35. def validate_environment(self, **kwargs):
  36. if not is_gptqmodel_available():
  37. raise ImportError(
  38. "Loading an AWQ quantized model requires gptqmodel. Please install it with `pip install gptqmodel`"
  39. )
  40. if not is_accelerate_available():
  41. raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")
  42. def update_dtype(self, dtype):
  43. if dtype == torch.bfloat16 and (torch.cuda.is_available() or torch.xpu.is_available()):
  44. logger.warning(
  45. "`torch.bfloat16` is not supported for AWQ CUDA/XPU kernels yet. Casting to `torch.float16`."
  46. )
  47. dtype = torch.float16
  48. elif dtype != torch.float16 and (torch.cuda.is_available() or torch.xpu.is_available()):
  49. logger.warning("We suggest you to set `dtype=torch.float16` for better efficiency on CUDA/XPU with AWQ.")
  50. return dtype
  51. def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
  52. from ..integrations import replace_quantization_scales, replace_with_awq_linear
  53. self.modules_to_not_convert = self.get_modules_to_not_convert(
  54. model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules, add_default_skips=True
  55. )
  56. model = replace_with_awq_linear(
  57. model,
  58. quantization_config=self.quantization_config,
  59. modules_to_not_convert=self.modules_to_not_convert,
  60. device_map=kwargs.get("device_map"),
  61. )
  62. model = replace_quantization_scales(model, model.config.model_type)
  63. def _process_model_after_weight_loading(self, model, **kwargs):
  64. from gptqmodel.utils.model import hf_gptqmodel_post_init
  65. hf_gptqmodel_post_init(model, use_act_order=self.quantization_config.desc_act)
  66. def is_serializable(self):
  67. if self.quantization_config.backend in [AwqBackend.EXLLAMA_V1, AwqBackend.EXLLAMA_V2]:
  68. logger.warning("You cannot save an AWQ model that uses Exllama backend!")
  69. return False
  70. return True
  71. @property
  72. def is_trainable(self):
  73. return version.parse(importlib.metadata.version("gptqmodel")) >= version.parse("5.0.0")