quantizer_torchao.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. from typing import TYPE_CHECKING
  16. from .base import HfQuantizer
  17. from .quantizers_utils import get_module_from_name, should_convert_module
  18. if TYPE_CHECKING:
  19. from ..modeling_utils import PreTrainedModel
  20. from ..utils.quantization_config import TorchAoConfig
  21. from safetensors import safe_open
  22. from ..utils import is_torch_available, is_torchao_available, logging
  23. MIN_TORCH_VERSION = "2.5.0"
  24. if is_torch_available():
  25. from ..core_model_loading import WeightConverter
  26. if is_torch_available():
  27. import torch
  28. if is_torchao_available():
  29. from torchao.prototype.safetensors.safetensors_support import (
  30. flatten_tensor_state_dict,
  31. )
  32. logger = logging.get_logger(__name__)
  33. def _fuzzy_match_size(config_name: str) -> str | None:
  34. """
  35. Extract the size digit from torchao config class names like "Int4WeightOnlyConfig", "Int8WeightOnlyConfig".
  36. Returns the digit as a string if found, otherwise None.
  37. """
  38. match = re.search(r"(\d)weight", config_name.lower())
  39. return match.group(1) if match else None
  40. class TorchAoHfQuantizer(HfQuantizer):
  41. """
  42. Quantizer for torchao: https://github.com/pytorch/ao/
  43. """
  44. requires_calibration = False
  45. quantization_config: "TorchAoConfig"
  46. def __init__(self, quantization_config, **kwargs):
  47. super().__init__(quantization_config, **kwargs)
  48. size_digit = _fuzzy_match_size(type(self.quantization_config.quant_type).__name__)
  49. self.quantized_param_size = 0.5 if size_digit == "4" else 1
  50. def validate_environment(self, *args, **kwargs):
  51. if not is_torchao_available():
  52. raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")
  53. device_map = kwargs.get("device_map")
  54. self.offload_to_cpu = False
  55. if isinstance(device_map, dict):
  56. if ("disk" in device_map.values() or "cpu" in device_map.values()) and len(device_map) > 1:
  57. self.offload_to_cpu = "cpu" in device_map.values()
  58. if self.pre_quantized and "disk" in device_map.values():
  59. raise ValueError(
  60. "You are attempting to perform disk offload with a pre-quantized torchao model "
  61. "This is not supported yet . Please remove the disk device from the device_map."
  62. )
  63. def get_state_dict_and_metadata(self, model):
  64. """
  65. We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format.
  66. """
  67. return flatten_tensor_state_dict(model.state_dict())
  68. def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
  69. "Return the element size (in bytes) for `param_name`."
  70. if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
  71. return self.quantized_param_size
  72. return super().param_element_size(model, param_name, param)
  73. def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
  74. # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128
  75. max_memory = {key: val * 0.9 for key, val in max_memory.items()}
  76. return max_memory
  77. def _process_model_before_weight_loading(self, model: "PreTrainedModel", checkpoint_files=None, **kwargs):
  78. self.modules_to_not_convert = self.get_modules_to_not_convert(
  79. model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
  80. )
  81. if self.quantization_config.include_input_output_embeddings:
  82. input_emb = model.get_input_embeddings()
  83. input_emb_names = [name for name, module in model.named_modules() if id(module) == id(input_emb)]
  84. output_emb = model.get_output_embeddings()
  85. output_emb_names = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
  86. self.modules_to_not_convert = [
  87. x for x in self.modules_to_not_convert if x not in input_emb_names + output_emb_names
  88. ]
  89. if checkpoint_files is not None:
  90. # Torchao needs access to all metadata later
  91. self.set_metadata(checkpoint_files)
  92. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  93. # check if the param_name is not in self.modules_to_not_convert
  94. if not should_convert_module(param_name, self.modules_to_not_convert):
  95. return False
  96. # we only quantize the weight of nn.Linear and nn.Embedding
  97. module, tensor_name = get_module_from_name(model, param_name)
  98. _QUANTIZABLE = [torch.nn.Linear]
  99. if self.quantization_config.include_input_output_embeddings:
  100. _QUANTIZABLE.append(torch.nn.Embedding)
  101. from torchao.quantization import FqnToConfig, fqn_matches_fqn_config
  102. if isinstance(self.quantization_config.quant_type, FqnToConfig):
  103. module_fqn, _ = param_name.rsplit(".", 1)
  104. if (
  105. fqn_matches_fqn_config(module_fqn, self.quantization_config.quant_type)
  106. or fqn_matches_fqn_config(param_name, self.quantization_config.quant_type)
  107. or (
  108. "_default" in self.quantization_config.quant_type.fqn_to_config
  109. and isinstance(module, tuple(_QUANTIZABLE))
  110. )
  111. ):
  112. return True
  113. return isinstance(module, tuple(_QUANTIZABLE)) and tensor_name == "weight"
  114. def is_serializable(self) -> bool:
  115. return True
  116. @property
  117. def is_trainable(self) -> bool:
  118. # Only 8-bit quantization (e.g. Int8WeightOnly, Int8DynamicActivationInt8Weight) supports training
  119. return _fuzzy_match_size(type(self.quantization_config.quant_type).__name__) == "8"
  120. @property
  121. def is_compileable(self) -> bool:
  122. return True
  123. def set_metadata(self, checkpoint_files: list[str]):
  124. if checkpoint_files[0].endswith(".safetensors"):
  125. metadata = {}
  126. for checkpoint in checkpoint_files:
  127. with safe_open(checkpoint, framework="pt") as f:
  128. metadata_ = f.metadata() or {}
  129. metadata.update(metadata_)
  130. # Save it
  131. self.metadata = metadata
  132. def get_quantize_ops(self):
  133. from ..integrations.torchao import TorchAoQuantize
  134. return TorchAoQuantize(self)
  135. def get_weight_conversions(self):
  136. from ..integrations.torchao import TorchAoDeserialize
  137. if self.pre_quantized:
  138. return [
  139. WeightConverter(
  140. # TODO: incr flexibility by generalizing the source patterns to match the format of "_weight_"
  141. # note that the matching logic is greedy, so for ex, if _weight_scale is before _weight_scale_and_zero in this list, it will match _weight_scale always (this is incorrect)
  142. # thus, the order of source_patterns is intentional
  143. source_patterns=[
  144. "_weight_qdata",
  145. "_weight_scale_and_zero",
  146. "_weight_scale",
  147. "_weight_zero_point",
  148. "_weight_act_pre_scale",
  149. ],
  150. target_patterns="weight",
  151. operations=[TorchAoDeserialize(self)],
  152. ),
  153. ]
  154. return []