bitnet.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. from ..quantizers.quantizers_utils import should_convert_module
  2. from ..utils import is_torch_available, logging
  3. if is_torch_available():
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. logger = logging.get_logger(__name__)
# The weights are ternary, so each one fits in 2 bits; they are packed into uint8 tensors, which means every
# packed element holds 4 values.
  9. VALUES_PER_ITEM = 4
  10. def pack_weights(quantized_weights: torch.Tensor) -> torch.Tensor:
  11. """
  12. Packs a tensor of quantized weights into a compact format using 2 bits per value.
  13. Parameters:
  14. -----------
  15. quantized_weights : torch.Tensor
  16. A tensor containing ternary quantized weights with values in {-1, 0, 1}. These values are adjusted to
  17. {0, 1, 2} before being packed.
  18. Returns:
  19. --------
  20. torch.Tensor
  21. A packed tensor where each element stores 4 quantized values (each using 2 bits) in an 8-bit format.
  22. """
  23. original_shape = quantized_weights.shape
  24. row_dim = (original_shape[0] + VALUES_PER_ITEM - 1) // VALUES_PER_ITEM
  25. if len(original_shape) == 1:
  26. packed_tensor_shape = (row_dim,)
  27. else:
  28. packed_tensor_shape = (row_dim, *original_shape[1:])
  29. quantized_weights += 1
  30. packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
  31. unpacked = quantized_weights.to(torch.uint8)
  32. it = min(VALUES_PER_ITEM, (original_shape[0] // row_dim) + 1)
  33. for i in range(it):
  34. start = i * row_dim
  35. end = min(start + row_dim, original_shape[0])
  36. packed[: (end - start)] |= unpacked[start:end] << 2 * i
  37. return packed
  38. @torch.compile
  39. def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
  40. """
  41. Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value.
  42. Parameters:
  43. -----------
  44. packed : torch.Tensor
  45. A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value).
  46. dtype : torch.dtype
  47. The dtype of the returned Tensor
  48. Returns:
  49. --------
  50. torch.Tensor
  51. A tensor of unpacked weights, where each value is converted from its packed 2-bit representation.
  52. Example:
  53. --------
  54. packed = torch.tensor([[0b10100001, 0b00011000],
  55. [0b10010000, 0b00001010]], dtype=torch.uint8)
  56. # Unpack the values
  57. unpacked = unpack_weights(packed)
  58. # Resulting unpacked tensor
  59. print(unpacked)
  60. # Output: tensor([[ 0, -1],
  61. [-1, 1],
  62. [-1, 1],
  63. [-1, 1],
  64. [ 1, 0],
  65. [ 0, -1],
  66. [ 1, -1],
  67. [ 1, -1]])
  68. Explanation of the example:
  69. ---------------------------
  70. Let's take the first value for example 0b10100001, we will only focus on the first column,
  71. because every element is unpacked across the first dimension
  72. - First 2 bits: `01` → 0 at [0][0]
  73. - Second 2 bits: `00` → -1 at [0][2]
  74. - Third 2 bits: `10` → 1 at [0][4]
  75. - Fourth 2 bits: `10` → 1 at [0][6]
  76. the second value of the same row (0b10010000) will give the values for [0][1], [0][3], [0][5], [0][7]
  77. We subtract 1 because during the packing process, it's easier to work with values like 0, 1, and 2. To make this possible,
  78. we add 1 to the original ternary weights (which are typically -1, 0, and 1) when packing them. When unpacking, we reverse
  79. this by subtracting 1 to restore the original ternary values.
  80. """
  81. packed_shape = packed.shape
  82. if len(packed_shape) == 1:
  83. original_row_dim = packed_shape[0] * VALUES_PER_ITEM
  84. unpacked_shape = (original_row_dim,)
  85. else:
  86. original_row_dim = packed_shape[0] * VALUES_PER_ITEM
  87. unpacked_shape = (original_row_dim, *packed_shape[1:])
  88. unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
  89. for i in range(VALUES_PER_ITEM):
  90. start = i * packed_shape[0]
  91. end = start + packed_shape[0]
  92. mask = 3 << (2 * i)
  93. unpacked[start:end] = (packed & mask) >> (2 * i)
  94. return unpacked.to(dtype) - 1
  95. class BitLinear(nn.Module):
  96. def __init__(
  97. self,
  98. in_features: int,
  99. out_features: int,
  100. bias: bool,
  101. device=None,
  102. dtype=None,
  103. use_rms_norm: bool = False,
  104. rms_norm_eps: float = 1e-6,
  105. ):
  106. super().__init__()
  107. self.dtype = dtype
  108. self.in_features = in_features
  109. self.out_features = out_features
  110. self.register_buffer(
  111. "weight",
  112. torch.zeros(
  113. (out_features // VALUES_PER_ITEM, in_features),
  114. dtype=torch.uint8,
  115. device=device,
  116. ),
  117. )
  118. self.register_buffer(
  119. "weight_scale",
  120. torch.ones(
  121. (1),
  122. dtype=dtype,
  123. device=device,
  124. ),
  125. )
  126. if bias:
  127. self.register_buffer("bias", torch.zeros((out_features), dtype=dtype, device=device))
  128. else:
  129. self.bias = None
  130. # Optional RMSNorm (applied on the activations before quantization).
  131. self.rms_norm = None
  132. if use_rms_norm:
  133. from ..models.llama.modeling_llama import LlamaRMSNorm
  134. self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
  135. @torch.compile
  136. def activation_quant(self, input, num_bits=8):
  137. """
  138. Activation function : Performs symmetric, per-token quantization on the input activations.
  139. Parameters:
  140. -----------
  141. input : torch.Tensor
  142. Input activations to be quantized.
  143. num_bits : int, optional (default=8)
  144. Number of bits to use for quantization, determining the quantization range.
  145. Returns:
  146. --------
  147. result : torch.Tensor
  148. Quantized activation tensor, with values mapped to an `int8` range.
  149. scale : torch.Tensor
  150. The per-channel scaling factors used to quantize the tensor.
  151. """
  152. Qn = -(2 ** (num_bits - 1))
  153. Qp = 2 ** (num_bits - 1) - 1
  154. scale = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
  155. result = (input * scale).round().clamp(Qn, Qp)
  156. return result.to(torch.int8), scale
  157. @torch.compile
  158. def post_quant_process(self, input, input_scale, weight_scale):
  159. out = input / (input_scale * weight_scale)
  160. return out
  161. def forward(self, input):
  162. # Apply RMSNorm on the input if requested.
  163. if self.rms_norm is not None:
  164. input = self.rms_norm(input)
  165. w = self.weight
  166. w_quant = unpack_weights(w, dtype=self.dtype)
  167. input_quant, input_scale = self.activation_quant(input)
  168. y = F.linear(input_quant.to(self.dtype), w_quant)
  169. y = self.post_quant_process(y, self.weight_scale, input_scale)
  170. if self.bias is not None:
  171. y += self.bias.view(1, -1).expand_as(y)
  172. return y
  173. class WeightQuant(torch.autograd.Function):
  174. """
  175. Implements a custom autograd function for weight quantization.
  176. This performs ternary quantization (-1, 0, 1) based on scaling by the
  177. mean absolute value of the weights. It uses the Straight-Through Estimator
  178. (STE) for the backward pass.
  179. """
  180. @staticmethod
  181. @torch.compile
  182. def forward(ctx, weight):
  183. dtype = weight.dtype
  184. weight = weight.float()
  185. scale = 1.0 / weight.abs().mean().clamp_(min=1e-5)
  186. weight = (weight * scale).round().clamp(-1, 1) / scale
  187. return weight.to(dtype)
  188. @staticmethod
  189. def backward(ctx, grad_output):
  190. grad_input = grad_output.clone()
  191. return grad_input
  192. class ActQuant(torch.autograd.Function):
  193. """
  194. Implements a custom autograd function for activation quantization.
  195. This performs symmetric 8-bit quantization (to the range [-128, 127])
  196. based on the maximum absolute value along the last dimension (per-token/row scaling).
  197. It uses the Straight-Through Estimator (STE) for the backward pass.
  198. """
  199. @staticmethod
  200. @torch.compile
  201. def forward(ctx, activation):
  202. dtype = activation.dtype
  203. activation = activation.float()
  204. scale = 127 / activation.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
  205. activation = (activation * scale).round().clamp(-128, 127) / scale
  206. return activation.to(dtype)
  207. @staticmethod
  208. def backward(ctx, grad_output):
  209. grad_input = grad_output.clone()
  210. return grad_input
  211. class AutoBitLinear(nn.Linear):
  212. def __init__(
  213. self,
  214. in_features: int,
  215. out_features: int,
  216. bias: bool = True,
  217. device=None,
  218. dtype=None,
  219. online_quant: bool = False,
  220. use_rms_norm: bool = False,
  221. rms_norm_eps: float = 1e-6,
  222. ):
  223. super().__init__(in_features, out_features, bias)
  224. self.online_quant = online_quant
  225. # Optional RMSNorm
  226. self.rms_norm = None
  227. if use_rms_norm:
  228. from ..models.llama.modeling_llama import LlamaRMSNorm
  229. self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
  230. if not online_quant:
  231. self.register_buffer(
  232. "weight_scale",
  233. torch.ones(
  234. (1),
  235. dtype=dtype,
  236. device=device,
  237. ),
  238. )
  239. self._register_load_state_dict_pre_hook(self.load_hook)
  240. def load_hook(
  241. self,
  242. state_dict,
  243. prefix,
  244. *args,
  245. **kwargs,
  246. ):
  247. if (prefix + "weight") in state_dict and state_dict[prefix + "weight"].dtype != self.weight.dtype:
  248. state_dict[prefix + "weight"] = unpack_weights(state_dict[prefix + "weight"], dtype=self.weight.dtype)
  249. return state_dict
  250. def forward(self, input):
  251. # Optional RMSNorm on activations prior to quantization.
  252. if self.rms_norm is not None:
  253. input = self.rms_norm(input)
  254. if self.online_quant:
  255. weight = WeightQuant.apply(self.weight)
  256. else:
  257. weight = self.weight
  258. input = ActQuant.apply(input)
  259. output = F.linear(input, weight, self.bias)
  260. if not self.online_quant:
  261. output = output * self.weight_scale
  262. return output
  263. def replace_with_bitnet_linear(model, modules_to_not_convert: list[str] | None = None, quantization_config=None):
  264. """
  265. Public method that replaces the linear layers of the given model with bitnet quantized layers.
  266. Args:
  267. model (`torch.nn.Module`):
  268. The model to convert, can be any `torch.nn.Module` instance.
  269. modules_to_not_convert (`list[str]`, *optional*, defaults to `None`):
  270. A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be
  271. converted.
  272. quantization_config (`BitNetConfig`):
  273. The quantization config object that contains the quantization parameters.
  274. """
  275. has_been_replaced = False
  276. # we need this to correctly materialize the weights during quantization
  277. for module_name, module in model.named_modules():
  278. if not should_convert_module(module_name, modules_to_not_convert):
  279. continue
  280. with torch.device("meta"):
  281. if isinstance(module, nn.Linear):
  282. if quantization_config and quantization_config.linear_class == "autobitlinear":
  283. new_module = AutoBitLinear(
  284. in_features=module.in_features,
  285. out_features=module.out_features,
  286. bias=module.bias is not None,
  287. device=module.weight.device,
  288. dtype=module.weight.dtype,
  289. online_quant=(quantization_config.quantization_mode == "online"),
  290. use_rms_norm=quantization_config.use_rms_norm,
  291. rms_norm_eps=quantization_config.rms_norm_eps,
  292. )
  293. if quantization_config.quantization_mode == "offline":
  294. new_module.requires_grad_(False)
  295. else:
  296. new_module = BitLinear(
  297. in_features=module.in_features,
  298. out_features=module.out_features,
  299. bias=module.bias is not None,
  300. device=module.weight.device,
  301. dtype=module.weight.dtype,
  302. use_rms_norm=quantization_config.use_rms_norm if quantization_config else False,
  303. rms_norm_eps=quantization_config.rms_norm_eps if quantization_config else 1e-6,
  304. )
  305. new_module.requires_grad_(False)
  306. model.set_submodule(module_name, new_module)
  307. has_been_replaced = True
  308. if not has_been_replaced:
  309. logger.warning(
  310. "You are loading your model using bitnet but no linear modules were found in your model."
  311. " Please double check your model architecture, or submit an issue on github if you think this is"
  312. " a bug."
  313. )
  314. return model
  315. class BitNetDeserialize:
  316. def __init__(self, hf_quantizer):
  317. self.hf_quantizer = hf_quantizer
  318. def convert(
  319. self,
  320. input_dict: dict[str, list[torch.Tensor]],
  321. model: torch.nn.Module | None = None,
  322. full_layer_name: str | None = None,
  323. **kwargs,
  324. ) -> dict[str, torch.Tensor]:
  325. for key, value in input_dict.items():
  326. if isinstance(value, list):
  327. input_dict[key] = value[0]
  328. key_weight = "weight"
  329. weight = input_dict.pop(key_weight)
  330. from ..quantizers.quantizers_utils import get_module_from_name
  331. needs_unpacking = False
  332. target_dtype = weight.dtype
  333. if model is not None and full_layer_name is not None:
  334. module, _ = get_module_from_name(model, full_layer_name)
  335. if hasattr(module, "out_features") and hasattr(module, "in_features"):
  336. # Packed: shape[0] * VALUES_PER_ITEM == out_features
  337. # Unpacked: shape[0] == out_features
  338. expected_out = module.out_features
  339. actual_out = weight.shape[0]
  340. if actual_out * VALUES_PER_ITEM == expected_out:
  341. needs_unpacking = True
  342. if needs_unpacking:
  343. weight_uint8 = weight.to(torch.uint8)
  344. weight = unpack_weights(weight_uint8, dtype=target_dtype)
  345. return {key_weight: weight}