modular_bitnet.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # Copyright 2025 The BitNet Team and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and limitations under the License.
  13. """PyTorch BitNet model."""
  14. from collections.abc import Callable
  15. import torch
  16. from ...cache_utils import Cache
  17. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  18. from ...modeling_outputs import CausalLMOutputWithPast
  19. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  20. from ...processing_utils import Unpack
  21. from ...utils import logging
  22. from ..gemma.modeling_gemma import GemmaMLP
  23. from ..llama.modeling_llama import (
  24. LlamaAttention,
  25. LlamaDecoderLayer,
  26. LlamaForCausalLM,
  27. LlamaModel,
  28. LlamaRMSNorm,
  29. apply_rotary_pos_emb,
  30. eager_attention_forward,
  31. )
  32. from .configuration_bitnet import BitNetConfig
# Module-level logger from the library's `logging` utilities, named after this module.
# NOTE(review): not referenced by any code in this file; presumably kept for parity with the
# generated modeling file — confirm before removing.
logger = logging.get_logger(__name__)
  34. class BitNetRMSNorm(LlamaRMSNorm):
  35. pass
  36. class BitNetMLP(GemmaMLP):
  37. def __init__(self, config: BitNetConfig):
  38. super().__init__(config)
  39. self.ffn_sub_norm = BitNetRMSNorm(config.intermediate_size, eps=config.rms_norm_eps)
  40. def forward(self, x):
  41. down_proj = self.down_proj(self.ffn_sub_norm(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
  42. return down_proj
  43. class BitNetAttention(LlamaAttention):
  44. def __init__(self, config: BitNetConfig, layer_idx: int):
  45. super().__init__(config, layer_idx)
  46. self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  47. def forward(
  48. self,
  49. hidden_states: torch.Tensor,
  50. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  51. attention_mask: torch.Tensor | None,
  52. past_key_values: Cache | None = None,
  53. **kwargs: Unpack[FlashAttentionKwargs],
  54. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  55. input_shape = hidden_states.shape[:-1]
  56. hidden_shape = (*input_shape, -1, self.head_dim)
  57. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  58. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  59. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  60. cos, sin = position_embeddings
  61. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  62. if past_key_values is not None:
  63. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  64. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  65. self.config._attn_implementation, eager_attention_forward
  66. )
  67. attn_output, attn_weights = attention_interface(
  68. self,
  69. query_states,
  70. key_states,
  71. value_states,
  72. attention_mask,
  73. dropout=0.0 if not self.training else self.attention_dropout,
  74. scaling=self.scaling,
  75. **kwargs,
  76. )
  77. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  78. attn_output = self.attn_sub_norm(attn_output) # diff with Llama
  79. attn_output = self.o_proj(attn_output)
  80. return attn_output, attn_weights
  81. class BitNetDecoderLayer(LlamaDecoderLayer):
  82. pass
  83. class BitNetModel(LlamaModel):
  84. pass
  85. class BitNetForCausalLM(LlamaForCausalLM):
  86. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  87. _tp_plan = None
  88. _pp_plan = None
  89. def forward(
  90. self,
  91. **super_kwargs,
  92. ) -> CausalLMOutputWithPast:
  93. r"""
  94. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  95. Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
  96. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  97. (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
  98. Example:
  99. ```python
  100. >>> from transformers import AutoTokenizer, BitNetForCausalLM
  101. >>> model = BitNetForCausalLM.from_pretrained("microsoft/bitnet-b1.58-2B-4T")
  102. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/bitnet-b1.58-2B-4T")
  103. >>> prompt = f'<|begin_of_text|>User: Hey, are you conscious? Can you talk to me?<|eot_id|>Assistant: '
  104. >>> inputs = tokenizer(prompt, return_tensors="pt")
  105. >>> # Generate
  106. >>> generate_ids = model.generate(inputs.input_ids, max_length=100)
  107. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  108. "User: Hey, are you conscious? Can you talk to me?Assistant: No, I'm not conscious. I'm an artificial intelligence designed to assist with information and tasks. How can I help you today?"
  109. ```"""
  110. return super().forward(**super_kwargs)
# Public API of this module. `BitNetPreTrainedModel` is not defined anywhere in this file;
# `# noqa: F822` silences the linter's "undefined name in `__all__`" check for it
# (presumably the class is produced in the generated modeling file — confirm).
__all__ = [
    "BitNetForCausalLM",
    "BitNetModel",
    "BitNetPreTrainedModel",  # noqa: F822
]