modeling_bitnet.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/bitnet/modular_bitnet.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_bitnet.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The BitNet Team and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  19. from collections.abc import Callable
  20. from typing import Optional
  21. import torch
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...generation import GenerationMixin
  26. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  27. from ...masking_utils import create_causal_mask
  28. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  31. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  35. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_bitnet import BitNetConfig
  38. @use_kernel_forward_from_hub("RMSNorm")
  39. class BitNetRMSNorm(nn.Module):
  40. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  41. """
  42. BitNetRMSNorm is equivalent to T5LayerNorm
  43. """
  44. super().__init__()
  45. self.weight = nn.Parameter(torch.ones(hidden_size))
  46. self.variance_epsilon = eps
  47. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  48. input_dtype = hidden_states.dtype
  49. hidden_states = hidden_states.to(torch.float32)
  50. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  51. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  52. return self.weight * hidden_states.to(input_dtype)
  53. def extra_repr(self):
  54. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  55. class BitNetMLP(nn.Module):
  56. def __init__(self, config: BitNetConfig):
  57. super().__init__()
  58. self.config = config
  59. self.hidden_size = config.hidden_size
  60. self.intermediate_size = config.intermediate_size
  61. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  62. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  63. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  64. self.act_fn = ACT2FN[config.hidden_act]
  65. self.ffn_sub_norm = BitNetRMSNorm(config.intermediate_size, eps=config.rms_norm_eps)
  66. def forward(self, x):
  67. down_proj = self.down_proj(self.ffn_sub_norm(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
  68. return down_proj
  69. def rotate_half(x):
  70. """Rotates half the hidden dims of the input."""
  71. x1 = x[..., : x.shape[-1] // 2]
  72. x2 = x[..., x.shape[-1] // 2 :]
  73. return torch.cat((-x2, x1), dim=-1)
  74. @use_kernel_func_from_hub("rotary_pos_emb")
  75. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  76. """Applies Rotary Position Embedding to the query and key tensors.
  77. Args:
  78. q (`torch.Tensor`): The query tensor.
  79. k (`torch.Tensor`): The key tensor.
  80. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  81. sin (`torch.Tensor`): The sine part of the rotary embedding.
  82. unsqueeze_dim (`int`, *optional*, defaults to 1):
  83. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  84. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  85. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  86. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  87. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  88. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  89. Returns:
  90. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  91. """
  92. cos = cos.unsqueeze(unsqueeze_dim)
  93. sin = sin.unsqueeze(unsqueeze_dim)
  94. q_embed = (q * cos) + (rotate_half(q) * sin)
  95. k_embed = (k * cos) + (rotate_half(k) * sin)
  96. return q_embed, k_embed
  97. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  98. """
  99. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  100. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  101. """
  102. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  103. if n_rep == 1:
  104. return hidden_states
  105. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  106. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  107. def eager_attention_forward(
  108. module: nn.Module,
  109. query: torch.Tensor,
  110. key: torch.Tensor,
  111. value: torch.Tensor,
  112. attention_mask: torch.Tensor | None,
  113. scaling: float,
  114. dropout: float = 0.0,
  115. **kwargs: Unpack[TransformersKwargs],
  116. ):
  117. key_states = repeat_kv(key, module.num_key_value_groups)
  118. value_states = repeat_kv(value, module.num_key_value_groups)
  119. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  120. if attention_mask is not None:
  121. attn_weights = attn_weights + attention_mask
  122. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  123. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  124. attn_output = torch.matmul(attn_weights, value_states)
  125. attn_output = attn_output.transpose(1, 2).contiguous()
  126. return attn_output, attn_weights
  127. @use_kernelized_func(apply_rotary_pos_emb)
  128. class BitNetAttention(nn.Module):
  129. """Multi-headed attention from 'Attention Is All You Need' paper"""
  130. def __init__(self, config: BitNetConfig, layer_idx: int):
  131. super().__init__()
  132. self.config = config
  133. self.layer_idx = layer_idx
  134. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  135. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  136. self.scaling = self.head_dim**-0.5
  137. self.attention_dropout = config.attention_dropout
  138. self.is_causal = True
  139. self.q_proj = nn.Linear(
  140. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  141. )
  142. self.k_proj = nn.Linear(
  143. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  144. )
  145. self.v_proj = nn.Linear(
  146. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  147. )
  148. self.o_proj = nn.Linear(
  149. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  150. )
  151. self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  152. def forward(
  153. self,
  154. hidden_states: torch.Tensor,
  155. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  156. attention_mask: torch.Tensor | None,
  157. past_key_values: Cache | None = None,
  158. **kwargs: Unpack[FlashAttentionKwargs],
  159. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  160. input_shape = hidden_states.shape[:-1]
  161. hidden_shape = (*input_shape, -1, self.head_dim)
  162. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  163. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  164. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  165. cos, sin = position_embeddings
  166. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  167. if past_key_values is not None:
  168. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  169. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  170. self.config._attn_implementation, eager_attention_forward
  171. )
  172. attn_output, attn_weights = attention_interface(
  173. self,
  174. query_states,
  175. key_states,
  176. value_states,
  177. attention_mask,
  178. dropout=0.0 if not self.training else self.attention_dropout,
  179. scaling=self.scaling,
  180. **kwargs,
  181. )
  182. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  183. attn_output = self.attn_sub_norm(attn_output) # diff with Llama
  184. attn_output = self.o_proj(attn_output)
  185. return attn_output, attn_weights
  186. class BitNetDecoderLayer(GradientCheckpointingLayer):
  187. def __init__(self, config: BitNetConfig, layer_idx: int):
  188. super().__init__()
  189. self.hidden_size = config.hidden_size
  190. self.self_attn = BitNetAttention(config=config, layer_idx=layer_idx)
  191. self.mlp = BitNetMLP(config)
  192. self.input_layernorm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  193. self.post_attention_layernorm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  194. def forward(
  195. self,
  196. hidden_states: torch.Tensor,
  197. attention_mask: torch.Tensor | None = None,
  198. position_ids: torch.LongTensor | None = None,
  199. past_key_values: Cache | None = None,
  200. use_cache: bool | None = False,
  201. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  202. **kwargs: Unpack[TransformersKwargs],
  203. ) -> torch.Tensor:
  204. residual = hidden_states
  205. hidden_states = self.input_layernorm(hidden_states)
  206. # Self Attention
  207. hidden_states, _ = self.self_attn(
  208. hidden_states=hidden_states,
  209. attention_mask=attention_mask,
  210. position_ids=position_ids,
  211. past_key_values=past_key_values,
  212. use_cache=use_cache,
  213. position_embeddings=position_embeddings,
  214. **kwargs,
  215. )
  216. hidden_states = residual + hidden_states
  217. # Fully Connected
  218. residual = hidden_states
  219. hidden_states = self.post_attention_layernorm(hidden_states)
  220. hidden_states = self.mlp(hidden_states)
  221. hidden_states = residual + hidden_states
  222. return hidden_states
  223. class BitNetRotaryEmbedding(nn.Module):
  224. inv_freq: torch.Tensor # fix linting for `register_buffer`
  225. def __init__(self, config: BitNetConfig, device=None):
  226. super().__init__()
  227. self.max_seq_len_cached = config.max_position_embeddings
  228. self.original_max_seq_len = config.max_position_embeddings
  229. self.config = config
  230. self.rope_type = self.config.rope_parameters["rope_type"]
  231. rope_init_fn: Callable = self.compute_default_rope_parameters
  232. if self.rope_type != "default":
  233. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  234. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  235. self.register_buffer("inv_freq", inv_freq, persistent=False)
  236. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  237. @staticmethod
  238. def compute_default_rope_parameters(
  239. config: BitNetConfig | None = None,
  240. device: Optional["torch.device"] = None,
  241. seq_len: int | None = None,
  242. ) -> tuple["torch.Tensor", float]:
  243. """
  244. Computes the inverse frequencies according to the original RoPE implementation
  245. Args:
  246. config ([`~transformers.PreTrainedConfig`]):
  247. The model configuration.
  248. device (`torch.device`):
  249. The device to use for initialization of the inverse frequencies.
  250. seq_len (`int`, *optional*):
  251. The current sequence length. Unused for this type of RoPE.
  252. Returns:
  253. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  254. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  255. """
  256. base = config.rope_parameters["rope_theta"]
  257. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  258. attention_factor = 1.0 # Unused in this type of RoPE
  259. # Compute the inverse frequencies
  260. inv_freq = 1.0 / (
  261. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  262. )
  263. return inv_freq, attention_factor
  264. @torch.no_grad()
  265. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  266. def forward(self, x, position_ids):
  267. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  268. position_ids_expanded = position_ids[:, None, :].float()
  269. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  270. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  271. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  272. emb = torch.cat((freqs, freqs), dim=-1)
  273. cos = emb.cos() * self.attention_scaling
  274. sin = emb.sin() * self.attention_scaling
  275. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  276. @auto_docstring
  277. class BitNetPreTrainedModel(PreTrainedModel):
  278. config: BitNetConfig
  279. base_model_prefix = "model"
  280. supports_gradient_checkpointing = True
  281. _no_split_modules = ["BitNetDecoderLayer"]
  282. _skip_keys_device_placement = ["past_key_values"]
  283. _supports_flash_attn = True
  284. _supports_sdpa = True
  285. _supports_flex_attn = True
  286. _can_compile_fullgraph = True
  287. _supports_attention_backend = True
  288. _can_record_outputs = {
  289. "hidden_states": BitNetDecoderLayer,
  290. "attentions": BitNetAttention,
  291. }
  292. @auto_docstring
  293. class BitNetModel(BitNetPreTrainedModel):
  294. def __init__(self, config: BitNetConfig):
  295. super().__init__(config)
  296. self.padding_idx = config.pad_token_id
  297. self.vocab_size = config.vocab_size
  298. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  299. self.layers = nn.ModuleList(
  300. [BitNetDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  301. )
  302. self.norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  303. self.rotary_emb = BitNetRotaryEmbedding(config=config)
  304. self.gradient_checkpointing = False
  305. # Initialize weights and apply final processing
  306. self.post_init()
  307. @merge_with_config_defaults
  308. @capture_outputs
  309. @auto_docstring
  310. def forward(
  311. self,
  312. input_ids: torch.LongTensor | None = None,
  313. attention_mask: torch.Tensor | None = None,
  314. position_ids: torch.LongTensor | None = None,
  315. past_key_values: Cache | None = None,
  316. inputs_embeds: torch.FloatTensor | None = None,
  317. use_cache: bool | None = None,
  318. **kwargs: Unpack[TransformersKwargs],
  319. ) -> BaseModelOutputWithPast:
  320. if (input_ids is None) ^ (inputs_embeds is not None):
  321. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  322. if inputs_embeds is None:
  323. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  324. if use_cache and past_key_values is None:
  325. past_key_values = DynamicCache(config=self.config)
  326. if position_ids is None:
  327. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  328. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  329. position_ids = position_ids.unsqueeze(0)
  330. causal_mask = create_causal_mask(
  331. config=self.config,
  332. inputs_embeds=inputs_embeds,
  333. attention_mask=attention_mask,
  334. past_key_values=past_key_values,
  335. position_ids=position_ids,
  336. )
  337. hidden_states = inputs_embeds
  338. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  339. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  340. hidden_states = decoder_layer(
  341. hidden_states,
  342. attention_mask=causal_mask,
  343. position_embeddings=position_embeddings,
  344. position_ids=position_ids,
  345. past_key_values=past_key_values,
  346. use_cache=use_cache,
  347. **kwargs,
  348. )
  349. hidden_states = self.norm(hidden_states)
  350. return BaseModelOutputWithPast(
  351. last_hidden_state=hidden_states,
  352. past_key_values=past_key_values,
  353. )
  354. @auto_docstring
  355. class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
  356. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  357. _tp_plan = None
  358. _pp_plan = None
  359. def __init__(self, config):
  360. super().__init__(config)
  361. self.model = BitNetModel(config)
  362. self.vocab_size = config.vocab_size
  363. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  364. # Initialize weights and apply final processing
  365. self.post_init()
  366. @can_return_tuple
  367. @auto_docstring
  368. def forward(
  369. self,
  370. input_ids: torch.LongTensor | None = None,
  371. attention_mask: torch.Tensor | None = None,
  372. position_ids: torch.LongTensor | None = None,
  373. past_key_values: Cache | None = None,
  374. inputs_embeds: torch.FloatTensor | None = None,
  375. labels: torch.LongTensor | None = None,
  376. use_cache: bool | None = None,
  377. logits_to_keep: int | torch.Tensor = 0,
  378. **kwargs: Unpack[TransformersKwargs],
  379. ) -> CausalLMOutputWithPast:
  380. r"""
  381. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  382. Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
  383. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  384. (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
  385. Example:
  386. ```python
  387. >>> from transformers import AutoTokenizer, BitNetForCausalLM
  388. >>> model = BitNetForCausalLM.from_pretrained("microsoft/bitnet-b1.58-2B-4T")
  389. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/bitnet-b1.58-2B-4T")
  390. >>> prompt = f'<|begin_of_text|>User: Hey, are you conscious? Can you talk to me?<|eot_id|>Assistant: '
  391. >>> inputs = tokenizer(prompt, return_tensors="pt")
  392. >>> # Generate
  393. >>> generate_ids = model.generate(inputs.input_ids, max_length=100)
  394. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  395. "User: Hey, are you conscious? Can you talk to me?Assistant: No, I'm not conscious. I'm an artificial intelligence designed to assist with information and tasks. How can I help you today?"
  396. ```"""
  397. outputs: BaseModelOutputWithPast = self.model(
  398. input_ids=input_ids,
  399. attention_mask=attention_mask,
  400. position_ids=position_ids,
  401. past_key_values=past_key_values,
  402. inputs_embeds=inputs_embeds,
  403. use_cache=use_cache,
  404. **kwargs,
  405. )
  406. hidden_states = outputs.last_hidden_state
  407. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  408. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  409. logits = self.lm_head(hidden_states[:, slice_indices, :])
  410. loss = None
  411. if labels is not None:
  412. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  413. return CausalLMOutputWithPast(
  414. loss=loss,
  415. logits=logits,
  416. past_key_values=outputs.past_key_values,
  417. hidden_states=outputs.hidden_states,
  418. attentions=outputs.attentions,
  419. )
# Public API of this module: the names exported by `from <module> import *`.
__all__ = ["BitNetForCausalLM", "BitNetModel", "BitNetPreTrainedModel"]