bitsandbytes.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. import inspect
  2. from ..core_model_loading import ConversionOps
  3. from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
  4. from ..utils import (
  5. get_available_devices,
  6. is_accelerate_available,
  7. is_bitsandbytes_available,
  8. is_torch_available,
  9. logging,
  10. )
  11. if is_bitsandbytes_available():
  12. import bitsandbytes as bnb
  13. if is_torch_available():
  14. import torch
  15. import torch.nn as nn
  16. from ..pytorch_utils import Conv1D
  17. if is_accelerate_available():
  18. import accelerate
  19. from accelerate.hooks import add_hook_to_module, remove_hook_from_module
  20. logger = logging.get_logger(__name__)
  21. class Bnb4bitQuantize(ConversionOps):
  22. def __init__(self, hf_quantizer):
  23. self.hf_quantizer = hf_quantizer
  24. def convert(
  25. self,
  26. input_dict: dict[str, list[torch.Tensor]],
  27. full_layer_name: str | None = None,
  28. model: torch.nn.Module | None = None,
  29. **kwargs,
  30. ) -> dict[str, torch.Tensor]:
  31. """
  32. we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that it stored in hf_quantizer for now as we can't save it in the op since we create an op per tensor.
  33. """
  34. value = list(input_dict.values())[0]
  35. value = value[0]
  36. # update param name to get the weights instead of the quantized stats
  37. module, _ = get_module_from_name(model, full_layer_name)
  38. # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
  39. # Since weights are saved in the correct "orientation", we skip transposing when loading.
  40. if issubclass(module.source_cls, Conv1D):
  41. value = value.T
  42. old_value = model.get_parameter_or_buffer(full_layer_name)
  43. new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device)
  44. module._is_hf_initialized = True
  45. return {full_layer_name: new_value}
  46. class Bnb4bitDeserialize(ConversionOps):
  47. def __init__(self, hf_quantizer):
  48. self.hf_quantizer = hf_quantizer
  49. def convert(
  50. self,
  51. input_dict: dict[str, list[torch.Tensor]],
  52. model: torch.nn.Module | None = None,
  53. full_layer_name: str | None = None,
  54. **kwargs,
  55. ) -> dict[str, torch.Tensor]:
  56. """
  57. Deserialization of bnb keys. We need 6 keys to recreate the quantized weights
  58. """
  59. if len(input_dict) == 1:
  60. return input_dict
  61. for key, value in input_dict.items():
  62. if isinstance(value, list):
  63. input_dict[key] = value[0]
  64. key_weight = "weight"
  65. weight = input_dict.pop(key_weight)
  66. module, _ = get_module_from_name(model, full_layer_name)
  67. new_value = bnb.nn.Params4bit.from_prequantized(
  68. data=weight,
  69. quantized_stats=input_dict,
  70. requires_grad=False,
  71. device=weight.device,
  72. module=module,
  73. )
  74. module._is_hf_initialized = True
  75. return {key_weight: new_value}
  76. class Bnb8bitQuantize(ConversionOps):
  77. def __init__(self, hf_quantizer):
  78. self.hf_quantizer = hf_quantizer
  79. def convert(
  80. self,
  81. input_dict: dict[str, list[torch.Tensor]],
  82. model: torch.nn.Module | None = None,
  83. full_layer_name: str | None = None,
  84. **kwargs,
  85. ) -> dict[str, torch.Tensor]:
  86. value = list(input_dict.values())[0]
  87. value = value[0] if isinstance(value, list) else value
  88. module, _ = get_module_from_name(model, full_layer_name)
  89. # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
  90. # Since weights are saved in the correct "orientation", we skip transposing when loading.
  91. if issubclass(module.source_cls, Conv1D):
  92. value = value.T
  93. value_device = value.device
  94. kwargs = model.get_parameter_or_buffer(full_layer_name).__dict__
  95. kwargs.pop("SCB", None)
  96. new_value = bnb.nn.Int8Params(value.to("cpu"), requires_grad=False, **kwargs).to(value_device)
  97. return {full_layer_name: new_value}
  98. class Bnb8bitDeserialize(ConversionOps):
  99. def __init__(self, hf_quantizer):
  100. self.hf_quantizer = hf_quantizer
  101. def convert(
  102. self,
  103. input_dict: dict[str, list[torch.Tensor]],
  104. model: torch.nn.Module | None = None,
  105. full_layer_name: str | None = None,
  106. **kwargs,
  107. ) -> dict[str, torch.Tensor]:
  108. """
  109. Deserialization of bnb keys.
  110. """
  111. if len(input_dict) == 1:
  112. # special case when we only fetched the weight
  113. # since we collected keys, we need to return it like that
  114. return input_dict
  115. for key, value in input_dict.items():
  116. if isinstance(value, list):
  117. input_dict[key] = value[0]
  118. module, _ = get_module_from_name(model, full_layer_name)
  119. key_weight = "weight"
  120. weight = input_dict[key_weight]
  121. kwargs = model.get_parameter_or_buffer(full_layer_name).__dict__
  122. kwargs["SCB"] = input_dict["SCB"]
  123. new_value = bnb.nn.Int8Params(weight, requires_grad=False, **kwargs).to(weight.device)
  124. module._is_hf_initialized = True
  125. return {key_weight: new_value}
  126. def replace_with_bnb_linear(
  127. model: torch.nn.Module,
  128. modules_to_not_convert: list[str] | None = None,
  129. quantization_config=None,
  130. pre_quantized=False,
  131. ):
  132. """
  133. A helper function to replace all `torch.nn.Linear` modules by bnb modules from the `bitsandbytes` library.
  134. Args:
  135. model (`torch.nn.Module`):
  136. The model to convert, can be any `torch.nn.Module` instance.
  137. modules_to_not_convert (`list[str]`, defaults to `None`):
  138. A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
  139. converted.
  140. quantization_config (`BitsAndBytesConfig`):
  141. The quantization config object that contains the quantization parameters.
  142. pre_quantized (`book`, defaults to `False`):
  143. Whether the model is pre-quantized or not
  144. """
  145. has_been_replaced = False
  146. # we need this to correctly materialize the weights during quantization
  147. for module_name, module in model.named_modules():
  148. if not should_convert_module(module_name, modules_to_not_convert):
  149. continue
  150. new_module = None
  151. with torch.device("meta"):
  152. if isinstance(module, Conv1D) or type(module) is nn.Linear:
  153. if isinstance(module, Conv1D):
  154. in_features, out_features = module.weight.shape
  155. else:
  156. in_features = module.in_features
  157. out_features = module.out_features
  158. if quantization_config.quantization_method() == "llm_int8":
  159. new_module = bnb.nn.Linear8bitLt(
  160. in_features,
  161. out_features,
  162. module.bias is not None,
  163. has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
  164. threshold=quantization_config.llm_int8_threshold,
  165. )
  166. if pre_quantized:
  167. # this is kind of an edge case when supporting both loading and quantization ...
  168. # we need to set the right dtype as we cast the checkpoint with the dtype of the meta model
  169. new_module.weight.data = new_module.weight.data.to(dtype=torch.int8)
  170. else:
  171. new_module = bnb.nn.Linear4bit(
  172. in_features,
  173. out_features,
  174. module.bias is not None,
  175. quantization_config.bnb_4bit_compute_dtype,
  176. compress_statistics=quantization_config.bnb_4bit_use_double_quant,
  177. quant_type=quantization_config.bnb_4bit_quant_type,
  178. quant_storage=quantization_config.bnb_4bit_quant_storage,
  179. )
  180. if pre_quantized:
  181. # same here
  182. new_module.weight.data = new_module.weight.data.to(
  183. dtype=quantization_config.bnb_4bit_quant_storage
  184. )
  185. if new_module is not None:
  186. # Store the module class in case we need to transpose the weight later
  187. new_module.source_cls = type(module)
  188. # Force requires grad to False to avoid unexpected errors
  189. new_module.requires_grad_(False)
  190. model.set_submodule(module_name, new_module)
  191. has_been_replaced = True
  192. if not has_been_replaced:
  193. logger.warning(
  194. "You are loading your model using eetq but no linear modules were found in your model."
  195. " Please double check your model architecture, or submit an issue on github if you think this is"
  196. " a bug."
  197. )
  198. return model
  199. # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
  200. def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
  201. """
  202. Helper function to dequantize 4bit or 8bit bnb weights.
  203. If the weight is not a bnb quantized weight, it will be returned as is.
  204. """
  205. if not isinstance(weight, torch.nn.Parameter):
  206. raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
  207. cls_name = weight.__class__.__name__
  208. if cls_name not in ("Params4bit", "Int8Params"):
  209. return weight
  210. if cls_name == "Params4bit":
  211. output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
  212. return output_tensor
  213. if state.SCB is None:
  214. state.SCB = weight.SCB
  215. if hasattr(bnb.functional, "int8_vectorwise_dequant"):
  216. # Use bitsandbytes API if available (requires v0.45.0+)
  217. dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
  218. else:
  219. # Multiply by (scale/127) to dequantize.
  220. dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
  221. return dequantized
  222. def _create_accelerate_new_hook(old_hook):
  223. r"""
  224. Creates a new hook based on the old hook. Use it only if you know what you are doing !
  225. This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
  226. with some changes
  227. """
  228. old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
  229. old_hook_attr = old_hook.__dict__
  230. filtered_old_hook_attr = {}
  231. old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
  232. for k in old_hook_attr:
  233. if k in old_hook_init_signature.parameters:
  234. filtered_old_hook_attr[k] = old_hook_attr[k]
  235. new_hook = old_hook_cls(**filtered_old_hook_attr)
  236. return new_hook
  237. def dequantize_and_replace(model, quantization_config=None, dtype=None):
  238. """
  239. Converts a quantized model into its dequantized original version. The newly converted model will have
  240. some performance drop compared to the original model before quantization - use it only for specific usecases
  241. such as QLoRA adapters merging.
  242. Returns the converted model.
  243. """
  244. quant_method = quantization_config.quantization_method()
  245. target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
  246. for module_name, module in model.named_modules():
  247. if isinstance(module, target_cls):
  248. with torch.device("meta"):
  249. bias = getattr(module, "bias", None)
  250. new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
  251. state = module.state if quant_method == "llm_int8" else None
  252. new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
  253. weight = dequantize_bnb_weight(module.weight, state)
  254. if dtype is None:
  255. logger.warning_once(
  256. f"The modules are dequantized in {weight.dtype}. If you want to change the dtype, please specify `dtype` in `dequantize`. "
  257. )
  258. else:
  259. logger.warning_once(f"The modules are dequantized in {weight.dtype} and casted to {dtype}.")
  260. weight = weight.to(dtype)
  261. new_module.weight = torch.nn.Parameter(weight)
  262. if bias is not None:
  263. new_module.bias = bias
  264. if hasattr(module, "_hf_hook"):
  265. old_hook = module._hf_hook
  266. new_hook = _create_accelerate_new_hook(old_hook)
  267. remove_hook_from_module(module)
  268. add_hook_to_module(new_module, new_hook)
  269. new_module.to(module.weight.device)
  270. model.set_submodule(module_name, new_module)
  271. has_been_replaced = True
  272. if not has_been_replaced:
  273. logger.warning(
  274. "For some reason the model has not been properly dequantized. You might see unexpected behavior."
  275. )
  276. return model
  277. def validate_bnb_backend_availability(raise_exception=False):
  278. """
  279. Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
  280. """
  281. bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
  282. available_devices = set(get_available_devices())
  283. if not available_devices.intersection(bnb_supported_devices):
  284. if raise_exception:
  285. err_msg = (
  286. f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices}`. "
  287. "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation"
  288. )
  289. raise RuntimeError(err_msg)
  290. logger.warning("No supported devices found for bitsandbytes")
  291. return False
  292. return True