quantizer_bnb_4bit.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from .base import HfQuantizer
  16. from .quantizers_utils import get_module_from_name
  17. if TYPE_CHECKING:
  18. from ..modeling_utils import PreTrainedModel
  19. from ..utils.quantization_config import BitsAndBytesConfig
  20. from ..utils import (
  21. ACCELERATE_MIN_VERSION,
  22. BITSANDBYTES_MIN_VERSION,
  23. is_accelerate_available,
  24. is_bitsandbytes_available,
  25. is_torch_available,
  26. is_torch_hpu_available,
  27. is_torch_npu_available,
  28. is_torch_xpu_available,
  29. logging,
  30. )
  31. if is_torch_available():
  32. import torch
  33. from ..core_model_loading import WeightConverter
  34. logger = logging.get_logger(__name__)
  class Bnb4BitHfQuantizer(HfQuantizer):
      """
      4-bit quantization from bitsandbytes quantization method.

      Before weights are loaded, module replacement is delegated to `replace_with_bnb_linear`;
      the quantize / deserialize operations used during weight loading are exposed through
      `get_quantize_ops` and `get_weight_conversions`.
      """

      requires_calibration = False
      quantization_config: "BitsAndBytesConfig"

      def __init__(self, quantization_config, **kwargs):
          super().__init__(quantization_config, **kwargs)

      def validate_environment(self, *args, **kwargs):
          """
          Check that the environment and the requested `device_map` allow 4-bit bitsandbytes loading.

          Args:
              *args: Unused.
              **kwargs: Only `device_map` is read.

          Raises:
              ImportError: If accelerate or bitsandbytes is not available.
              ValueError: If `device_map` is a dict that mixes "cpu"/"disk" targets with other
                  targets (or uses "disk" at all) while `llm_int8_enable_fp32_cpu_offload` is disabled.
          """
          if not is_accelerate_available():
              raise ImportError(
                  f"Using `bitsandbytes` 4-bit quantization requires accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
              )
          if not is_bitsandbytes_available():
              # The requirement is quoted so that a shell does not interpret `>` as a redirection.
              raise ImportError(
                  f"Using `bitsandbytes` 4-bit quantization requires bitsandbytes: `pip install -U 'bitsandbytes>={BITSANDBYTES_MIN_VERSION}'`"
              )

          from ..integrations import validate_bnb_backend_availability

          validate_bnb_backend_availability(raise_exception=True)

          device_map = kwargs.get("device_map")
          # Nothing more to check when fp32 CPU offload is enabled or when there is no per-module map.
          if self.quantization_config.llm_int8_enable_fp32_cpu_offload or not isinstance(device_map, dict):
              return

          targets = set(device_map.values())
          # A map that places everything on "cpu" is accepted; any other use of "cpu" or "disk" is not.
          if targets != {"cpu"} and ("cpu" in targets or "disk" in targets):
              raise ValueError(
                  "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                  "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
                  "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
                  "`from_pretrained`. Check "
                  "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
                  "for more details. "
              )

      def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
          """Return the element size (in bytes) for `param_name`."""
          if self.param_needs_quantization(model, param_name):
              # 4 bits per element, i.e. half a byte.
              return 0.5
          return super().param_element_size(model, param_name, param)

      def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
          """Return True when `param_name` is a non-bias parameter of a `bnb.nn.Linear4bit` module."""
          import bitsandbytes as bnb

          module, name = get_module_from_name(model, param_name)
          return isinstance(module, bnb.nn.Linear4bit) and name != "bias"

      def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
          """Return a copy of `max_memory` with every value scaled to 90%."""
          # need more space for buffers that are created during quantization
          # NOTE(review): values are multiplied by a float, so a string size would raise a TypeError;
          # presumably callers pass numeric values here -- confirm against the caller.
          return {key: val * 0.90 for key, val in max_memory.items()}

      def update_device_map(self, device_map):
          """
          Return `device_map`, choosing a single-device default when it is None.

          The default is picked in this order: CUDA, NPU, HPU, XPU, then CPU.
          """
          if device_map is not None:
              return device_map

          if torch.cuda.is_available():
              device_map = {"": torch.cuda.current_device()}
          elif is_torch_npu_available() and hasattr(torch, "npu"):
              device_map = {"": f"npu:{torch.npu.current_device()}"}
          elif is_torch_hpu_available() and hasattr(torch, "hpu"):
              device_map = {"": f"hpu:{torch.hpu.current_device()}"}
          elif is_torch_xpu_available():
              device_map = {"": torch.xpu.current_device()}
          else:
              device_map = {"": "cpu"}

          logger.info(
              "The device_map was not initialized. "
              f"Setting device_map to {device_map}. "
              "If you want to use the model for inference, please set device_map ='auto' "
          )
          return device_map

      def _process_model_before_weight_loading(
          self,
          model: "PreTrainedModel",
          device_map,
          **kwargs,
      ):
          """
          Compute `self.modules_to_not_convert` and call `replace_with_bnb_linear` on `model`.

          Args:
              model: The model handed to `replace_with_bnb_linear`.
              device_map: When it is a dict and fp32 CPU offload is enabled, its keys mapped to
                  "cpu" or "disk" are added to the modules that are not converted.
              **kwargs: Unused.
          """
          from ..integrations import replace_with_bnb_linear

          self.modules_to_not_convert = self.get_modules_to_not_convert(
              model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
          )

          if self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
              self.modules_to_not_convert.extend(
                  key for key, value in device_map.items() if value in ("disk", "cpu")
              )

          # The return value is not needed: this method returns nothing to its caller.
          replace_with_bnb_linear(
              model,
              modules_to_not_convert=self.modules_to_not_convert,
              quantization_config=self.quantization_config,
              pre_quantized=self.pre_quantized,
          )

      def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
          """Flag `model` as loaded in 4-bit, record whether it is serializable, and return it."""
          model.is_loaded_in_4bit = True
          model.is_4bit_serializable = self.is_serializable()
          return model

      def is_serializable(self):
          """Return True unconditionally."""
          return True

      @property
      def is_trainable(self) -> bool:
          """Always True for this quantizer."""
          return True

      def _dequantize(self, model, dtype=None):
          """Return the result of `dequantize_and_replace` for `model` with this quantizer's config and `dtype`."""
          from ..integrations import dequantize_and_replace

          return dequantize_and_replace(model, quantization_config=self.quantization_config, dtype=dtype)

      def get_quantize_ops(self):
          """Return a `Bnb4bitQuantize` operation bound to this quantizer."""
          from ..integrations.bitsandbytes import Bnb4bitQuantize

          return Bnb4bitQuantize(self)

      def get_weight_conversions(self):
          """
          Return the weight converters to use while loading.

          For a pre-quantized checkpoint this is a single `WeightConverter` that maps the listed
          `weight*` source patterns onto `weight` through `Bnb4bitDeserialize`; otherwise the list
          is empty.
          """
          from ..integrations.bitsandbytes import Bnb4bitDeserialize

          if not self.pre_quantized:
              return []

          return [
              WeightConverter(
                  source_patterns=[
                      "weight.nested_absmax",
                      "weight.nested_quant_map",
                      "weight.quant_map",
                      "weight.absmax",
                      "weight.quant_state.bitsandbytes__nf4",
                      "weight.quant_state.bitsandbytes__fp4",
                      "weight",
                  ],
                  target_patterns="weight",
                  operations=[Bnb4bitDeserialize(self)],
              )
          ]