# modular_nanochat.py
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable

import torch
import torch.nn as nn

from ... import initialization as init
from ...cache_utils import Cache, DynamicCache
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring
from ..clip.modeling_clip import CLIPMLP
from ..gemma2.modeling_gemma2 import Gemma2ForCausalLM
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    LlamaPreTrainedModel,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    eager_attention_forward,
)
from ..llama4.modeling_llama4 import Llama4TextL2Norm
from ..qwen3.modeling_qwen3 import Qwen3Attention
from .configuration_nanochat import NanoChatConfig
  38. class NanoChatRMSNorm(Llama4TextL2Norm):
  39. pass
  40. class NanoChatRotaryEmbedding(LlamaRotaryEmbedding):
  41. pass
  42. def rotate_half(x):
  43. """Rotates half the hidden dims of the input with flipped signs for NanoChat."""
  44. x1 = x[..., : x.shape[-1] // 2]
  45. x2 = x[..., x.shape[-1] // 2 :]
  46. return torch.cat((x2, -x1), dim=-1)
  47. class NanoChatAttention(Qwen3Attention):
  48. def __init__(self, config: NanoChatConfig, layer_idx: int):
  49. super().__init__(config, layer_idx)
  50. del self.sliding_window
  51. del self.layer_type
  52. self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  53. self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  54. def forward(
  55. self,
  56. hidden_states: torch.Tensor,
  57. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  58. attention_mask: torch.Tensor | None = None,
  59. past_key_values: Cache | None = None,
  60. **kwargs: Unpack[TransformersKwargs],
  61. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  62. input_shape = hidden_states.shape[:-1]
  63. hidden_shape = (*input_shape, -1, self.head_dim)
  64. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  65. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  66. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  67. cos, sin = position_embeddings
  68. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  69. # RoPE -> Norm (instead of usual Norm -> RoPE)
  70. query_states = self.q_norm(query_states)
  71. key_states = self.k_norm(key_states)
  72. if past_key_values is not None:
  73. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  74. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  75. self.config._attn_implementation, eager_attention_forward
  76. )
  77. attn_output, attn_weights = attention_interface(
  78. self,
  79. query_states,
  80. key_states,
  81. value_states,
  82. attention_mask,
  83. dropout=0.0 if not self.training else self.attention_dropout,
  84. scaling=self.scaling,
  85. **kwargs,
  86. )
  87. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  88. attn_output = self.o_proj(attn_output)
  89. return attn_output, attn_weights
  90. class NanoChatMLP(CLIPMLP):
  91. def __init__(self, config):
  92. super().__init__(config)
  93. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
  94. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  95. class NanoChatDecoderLayer(LlamaDecoderLayer):
  96. def __init__(self, config: NanoChatConfig, layer_idx: int):
  97. super().__init__()
  98. self.input_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  99. self.post_attention_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  100. @auto_docstring
  101. class NanoChatPreTrainedModel(LlamaPreTrainedModel):
  102. def _init_weights(self, module: nn.Module) -> None:
  103. PreTrainedModel._init_weights(self, module)
  104. if isinstance(module, NanoChatAttention):
  105. init.normal_(
  106. module.o_proj.weight,
  107. mean=0.0,
  108. std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
  109. )
  110. @auto_docstring
  111. class NanoChatModel(LlamaModel):
  112. def __init__(self, config: NanoChatConfig):
  113. super().__init__(config)
  114. self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  115. def forward(
  116. self,
  117. input_ids: torch.LongTensor | None = None,
  118. attention_mask: torch.Tensor | None = None,
  119. position_ids: torch.LongTensor | None = None,
  120. past_key_values: Cache | None = None,
  121. inputs_embeds: torch.FloatTensor | None = None,
  122. use_cache: bool | None = None,
  123. **kwargs: Unpack[TransformersKwargs],
  124. ) -> BaseModelOutputWithPast:
  125. if (input_ids is None) ^ (inputs_embeds is not None):
  126. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  127. if inputs_embeds is None:
  128. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  129. if use_cache and past_key_values is None:
  130. past_key_values = DynamicCache(config=self.config)
  131. if position_ids is None:
  132. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  133. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  134. position_ids = position_ids.unsqueeze(0)
  135. causal_mask = create_causal_mask(
  136. config=self.config,
  137. inputs_embeds=inputs_embeds,
  138. attention_mask=attention_mask,
  139. past_key_values=past_key_values,
  140. position_ids=position_ids,
  141. )
  142. hidden_states = inputs_embeds
  143. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  144. hidden_states = self.norm(hidden_states) # Additional norm before the layers
  145. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  146. hidden_states = decoder_layer(
  147. hidden_states,
  148. attention_mask=causal_mask,
  149. position_embeddings=position_embeddings,
  150. position_ids=position_ids,
  151. past_key_values=past_key_values,
  152. **kwargs,
  153. )
  154. hidden_states = self.norm(hidden_states)
  155. return BaseModelOutputWithPast(
  156. last_hidden_state=hidden_states,
  157. past_key_values=past_key_values,
  158. )
  159. @auto_docstring
  160. class NanoChatForCausalLM(Gemma2ForCausalLM):
  161. _tp_plan = {"lm_head": "colwise_gather_output"}
  162. def forward(self, **super_kwargs) -> CausalLMOutputWithPast:
  163. r"""
  164. Example:
  165. ```python
  166. >>> from transformers import AutoTokenizer, AutoModelForCausalLM
  167. >>> model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
  168. >>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")
  169. >>> conversation = [
  170. {"role": "user", "content": "What is the capital of France?"},
  171. ]
  172. >>> inputs = tokenizer.apply_chat_template(
  173. conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
  174. ).to(device)
  175. >>> with torch.no_grad():
  176. >>> outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)
  177. >>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
  178. >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
  179. ```"""
  180. super().forward(**super_kwargs)
  181. __all__ = [
  182. "NanoChatPreTrainedModel",
  183. "NanoChatModel",
  184. "NanoChatForCausalLM",
  185. ]