quantizer_bnb_8bit.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from .base import HfQuantizer
  16. if TYPE_CHECKING:
  17. from ..modeling_utils import PreTrainedModel
  18. from ..utils.quantization_config import BitsAndBytesConfig
  19. from ..utils import (
  20. ACCELERATE_MIN_VERSION,
  21. BITSANDBYTES_MIN_VERSION,
  22. is_accelerate_available,
  23. is_bitsandbytes_available,
  24. is_torch_available,
  25. is_torch_hpu_available,
  26. is_torch_npu_available,
  27. is_torch_xpu_available,
  28. logging,
  29. )
  30. from .quantizers_utils import get_module_from_name
  31. if is_torch_available():
  32. import torch
  33. from ..core_model_loading import WeightConverter
# Module-level logger obtained from the package's `..utils.logging` helper, named after this module.
logger = logging.get_logger(__name__)
  35. class Bnb8BitHfQuantizer(HfQuantizer):
  36. """
  37. 8-bit quantization from bitsandbytes quantization method
  38. """
  39. requires_calibration = False
  40. quantization_config: "BitsAndBytesConfig"
  41. def __init__(self, quantization_config, **kwargs):
  42. super().__init__(quantization_config, **kwargs)
  43. def validate_environment(self, *args, **kwargs):
  44. if not is_accelerate_available():
  45. raise ImportError(
  46. f"Using `bitsandbytes` 8-bit quantization requires accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
  47. )
  48. if not is_bitsandbytes_available():
  49. raise ImportError(
  50. f"Using `bitsandbytes` 8-bit quantization requires bitsandbytes: `pip install -U bitsandbytes>={BITSANDBYTES_MIN_VERSION}`"
  51. )
  52. from ..integrations import validate_bnb_backend_availability
  53. validate_bnb_backend_availability(raise_exception=True)
  54. device_map = kwargs.get("device_map")
  55. if not self.quantization_config.llm_int8_enable_fp32_cpu_offload and isinstance(device_map, dict):
  56. values = set(device_map.values())
  57. if values != {"cpu"} and ("cpu" in values or "disk" in values):
  58. raise ValueError(
  59. "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
  60. "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
  61. "in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to "
  62. "`from_pretrained`. Check "
  63. "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
  64. "for more details. "
  65. )
  66. def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
  67. # need more space for buffers that are created during quantization
  68. max_memory = {key: val * 0.90 for key, val in max_memory.items()}
  69. return max_memory
  70. def update_device_map(self, device_map):
  71. if device_map is None:
  72. if torch.cuda.is_available():
  73. device_map = {"": torch.cuda.current_device()}
  74. elif is_torch_npu_available() and hasattr(torch, "npu"):
  75. device_map = {"": f"npu:{torch.npu.current_device()}"}
  76. elif is_torch_hpu_available() and hasattr(torch, "hpu"):
  77. device_map = {"": f"hpu:{torch.hpu.current_device()}"}
  78. elif is_torch_xpu_available():
  79. device_map = {"": torch.xpu.current_device()}
  80. else:
  81. device_map = {"": "cpu"}
  82. logger.info(
  83. "The device_map was not initialized. "
  84. f"Setting device_map to {device_map}. "
  85. "If you want to use the model for inference, please set device_map ='auto' "
  86. )
  87. return device_map
  88. def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
  89. "Return the element size (in bytes) for `param_name`."
  90. if self.param_needs_quantization(model, param_name):
  91. # 8-bit
  92. return 1
  93. return super().param_element_size(model, param_name, param)
  94. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  95. import bitsandbytes as bnb
  96. module, name = get_module_from_name(model, param_name)
  97. return isinstance(module, bnb.nn.Linear8bitLt) and name != "bias"
  98. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  99. setattr(model, "is_loaded_in_8bit", True)
  100. model.is_8bit_serializable = self.is_serializable()
  101. return model
  102. def _process_model_before_weight_loading(
  103. self,
  104. model: "PreTrainedModel",
  105. device_map,
  106. **kwargs,
  107. ):
  108. from ..integrations import replace_with_bnb_linear
  109. self.modules_to_not_convert = self.get_modules_to_not_convert(
  110. model, self.quantization_config.llm_int8_skip_modules, model._keep_in_fp32_modules
  111. )
  112. if self.quantization_config.llm_int8_enable_fp32_cpu_offload:
  113. if isinstance(device_map, dict):
  114. keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
  115. self.modules_to_not_convert.extend(keys_on_cpu)
  116. model = replace_with_bnb_linear(
  117. model,
  118. modules_to_not_convert=self.modules_to_not_convert,
  119. quantization_config=self.quantization_config,
  120. pre_quantized=self.pre_quantized,
  121. )
  122. def is_serializable(self):
  123. return True
  124. @property
  125. def is_trainable(self) -> bool:
  126. return True
  127. def _dequantize(self, model, dtype=None):
  128. from ..integrations import dequantize_and_replace
  129. model = dequantize_and_replace(model, quantization_config=self.quantization_config, dtype=dtype)
  130. return model
  131. def get_quantize_ops(self):
  132. from ..integrations.bitsandbytes import Bnb8bitQuantize
  133. return Bnb8bitQuantize(self)
  134. def get_weight_conversions(self):
  135. from ..integrations.bitsandbytes import Bnb8bitDeserialize
  136. if self.pre_quantized:
  137. return [
  138. WeightConverter(
  139. source_patterns=["SCB", "weight_format", "weight"],
  140. target_patterns="weight",
  141. operations=[Bnb8bitDeserialize(self)],
  142. )
  143. ]
  144. return []