fouroversix.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import torch
  2. from ..quantizers.quantizers_utils import get_module_from_name
  3. from ..utils import is_fouroversix_available
  4. if is_fouroversix_available():
  5. from fouroversix import ModelQuantizationConfig
  6. from transformers.utils.quantization_config import FourOverSixConfig
  7. from ..core_model_loading import ConversionOps
  8. class FourOverSixQuantize(ConversionOps):
  9. def __init__(self, hf_quantizer):
  10. self.hf_quantizer = hf_quantizer
  11. def convert(
  12. self,
  13. input_dict: dict[str, torch.Tensor],
  14. model: torch.nn.Module | None = None,
  15. full_layer_name: str | None = None,
  16. missing_keys: list[str] | None = None,
  17. **kwargs,
  18. ) -> dict[str, torch.Tensor]:
  19. """
  20. We need to store some parameters to create the quantized weight. For example, fouroversix
  21. requires 4 values that are stored in the checkpoint to recover the quantized weight. So we
  22. store them in a dict that is stored in hf_quantizer for now as we can't save it in the op
  23. since we create an op per tensor.
  24. """
  25. if self.hf_quantizer.quantization_config.keep_master_weights:
  26. return input_dict
  27. module, _ = get_module_from_name(model, full_layer_name)
  28. module_name = full_layer_name.rsplit(".", 1)[0]
  29. full_parameter_name = list(input_dict.keys())[0]
  30. parameter_name = full_parameter_name.replace(f"{module_name}.", "", 1)
  31. parameter = input_dict[full_parameter_name][0]
  32. quantized_parameters = module.get_quantized_parameters(parameter_name, parameter)
  33. # Delete the high-precision parameters from the module after we used them to create
  34. # the quantized parameters
  35. if hasattr(module, parameter_name):
  36. delattr(module, parameter_name)
  37. # Remove these keys from the missing_keys list since we've deleted them from the model
  38. for key in input_dict:
  39. missing_keys.discard(key)
  40. return {
  41. f"{module_name}.{quantized_key}": quantized_parameters[quantized_key]
  42. for quantized_key in quantized_parameters
  43. }
  44. def adapt_fouroversix_config(config: FourOverSixConfig):
  45. return ModelQuantizationConfig(
  46. activation_scale_rule=config.activation_scale_rule,
  47. dtype=config.dtype,
  48. gradient_scale_rule=config.gradient_scale_rule,
  49. keep_master_weights=config.keep_master_weights,
  50. matmul_backend=config.matmul_backend,
  51. output_dtype=config.output_dtype,
  52. quantize_backend=config.quantize_backend,
  53. scale_rule=config.scale_rule,
  54. weight_scale_2d=config.weight_scale_2d,
  55. weight_scale_rule=config.weight_scale_rule,
  56. modules_to_not_convert=config.modules_to_not_convert,
  57. module_config_overrides=config.module_config_overrides,
  58. )