quantizer_higgs.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from ..utils.logging import tqdm
  16. from .base import HfQuantizer
  17. from .quantizers_utils import get_module_from_name
  18. if TYPE_CHECKING:
  19. from ..modeling_utils import PreTrainedModel
  20. from ..utils.quantization_config import HiggsConfig
  21. from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging
  22. from ..utils.quantization_config import QuantizationConfigMixin
  23. if is_torch_available():
  24. import torch
# Module-level logger obtained through the project's `logging` utility (imported from `..utils` above).
logger = logging.get_logger(__name__)
  26. class HiggsHfQuantizer(HfQuantizer):
  27. """
  28. Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
  29. """
  30. requires_calibration = False
  31. quantization_config: "HiggsConfig"
  32. def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
  33. super().__init__(quantization_config, **kwargs)
  34. def validate_environment(self, device_map, **kwargs):
  35. if not torch.cuda.is_available():
  36. raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")
  37. if not is_accelerate_available():
  38. raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")
  39. if not is_flute_available():
  40. raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")
  41. if not is_hadamard_available():
  42. raise ImportError(
  43. "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
  44. )
  45. if device_map is None:
  46. raise ValueError(
  47. "You are attempting to load a HIGGS model without setting device_map."
  48. " Please set device_map comprised of 'cuda' devices."
  49. )
  50. elif isinstance(device_map, dict):
  51. if "cpu" in device_map.values() or "disk" in device_map.values():
  52. raise ValueError(
  53. "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
  54. " This is not supported. Please remove the CPU or disk device from the device_map."
  55. )
  56. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  57. if dtype != torch.float16 and dtype != torch.bfloat16:
  58. raise ValueError(
  59. f"Invalid `dtype` {dtype}. HIGGS quantization only supports `dtype=torch.float16` or `dtype=torch.bfloat16`."
  60. )
  61. return dtype
  62. # TODO: to remove
  63. # Kept here in case we see some interest in adding support for it
  64. # def create_quantized_param(
  65. # self,
  66. # model: "PreTrainedModel",
  67. # param_value: "torch.Tensor",
  68. # param_name: str,
  69. # target_device: "torch.device",
  70. # **kwargs,
  71. # ):
  72. # from ..integrations import quantize_with_higgs
  73. # flute_dict = quantize_with_higgs(
  74. # param_value.to(target_device),
  75. # self.quantization_config.bits,
  76. # self.quantization_config.p,
  77. # self.quantization_config.group_size,
  78. # self.quantization_config.hadamard_size,
  79. # )
  80. # del param_value
  81. # module, _ = get_module_from_name(model, param_name)
  82. # module_name = ".".join(param_name.split(".")[:-1])
  83. # for key, value in flute_dict.items():
  84. # if key in module._parameters:
  85. # module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
  86. # elif key in module._buffers:
  87. # module._buffers[key] = torch.nn.Buffer(value)
  88. # elif key == "tune_metadata":
  89. # module.tune_metadata = value
  90. # self.quantization_config.tune_metadata[module_name] = value.to_dict()
  91. # else:
  92. # raise ValueError(f"Unexpected key {key} in module {module}")
  93. def _process_model_before_weight_loading(
  94. self,
  95. model: "PreTrainedModel",
  96. **kwargs,
  97. ):
  98. from ..integrations import replace_with_higgs_linear
  99. self.modules_to_not_convert = self.get_modules_to_not_convert(
  100. model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
  101. )
  102. replace_with_higgs_linear(
  103. model,
  104. quantization_config=self.quantization_config,
  105. modules_to_not_convert=self.modules_to_not_convert,
  106. )
  107. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  108. from flute.tune import TuneMetaData, maybe_tune_and_repack
  109. from flute.utils import make_workspace_streamk
  110. from ..integrations import HiggsLinear
  111. flute_workspaces = {}
  112. flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
  113. for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
  114. # Every HiggsLinear needs a "workspace": a buffer for the unpacking operation.
  115. # This buffer needs to be on the same device as the weights, but can be reused across modules otherwise.
  116. if module.weight.device not in flute_workspaces:
  117. flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
  118. module.workspace = flute_workspaces[module.weight.device]
  119. # FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors).
  120. # If the model is loaded on a different device than the one it was saved on, we need to repack the weights.
  121. module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
  122. module.weight.data, module.tune_metadata = maybe_tune_and_repack(
  123. weight=module.weight.data,
  124. scales=module.scales.data,
  125. metadata=module.tune_metadata,
  126. )
  127. self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()
  128. @property
  129. def is_trainable(self) -> bool:
  130. return False
  131. def is_serializable(self):
  132. return True
  133. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  134. from ..integrations import HiggsLinear
  135. module, tensor_name = get_module_from_name(model, param_name)
  136. if isinstance(module, HiggsLinear) and tensor_name == "weight":
  137. # Only quantize weights of HiggsLinear modules that are not already quantized
  138. return True
  139. else:
  140. return False
  141. def _dequantize(self, model):
  142. from ..integrations import dequantize_higgs
  143. model = dequantize_higgs(model)
  144. return model