modeling_nanochat.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/nanochat/modular_nanochat.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_nanochat.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ... import initialization as init
  26. from ...activations import ACT2FN
  27. from ...cache_utils import Cache, DynamicCache
  28. from ...generation import GenerationMixin
  29. from ...integrations import use_kernel_func_from_hub, use_kernelized_func
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  33. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import TransformersKwargs, auto_docstring
  37. from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults
  38. from ...utils.output_capturing import capture_outputs
  39. from .configuration_nanochat import NanoChatConfig
  40. class NanoChatRMSNorm(torch.nn.Module):
  41. def __init__(self, eps: float = 1e-6):
  42. super().__init__()
  43. self.eps = eps
  44. def _norm(self, x):
  45. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  46. def forward(self, x):
  47. return self._norm(x.float()).type_as(x)
  48. def extra_repr(self):
  49. return f"eps={self.eps}"
  50. class NanoChatRotaryEmbedding(nn.Module):
  51. inv_freq: torch.Tensor # fix linting for `register_buffer`
  52. def __init__(self, config: NanoChatConfig, device=None):
  53. super().__init__()
  54. self.max_seq_len_cached = config.max_position_embeddings
  55. self.original_max_seq_len = config.max_position_embeddings
  56. self.config = config
  57. self.rope_type = self.config.rope_parameters["rope_type"]
  58. rope_init_fn: Callable = self.compute_default_rope_parameters
  59. if self.rope_type != "default":
  60. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  61. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  62. self.register_buffer("inv_freq", inv_freq, persistent=False)
  63. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  64. @staticmethod
  65. def compute_default_rope_parameters(
  66. config: NanoChatConfig | None = None,
  67. device: Optional["torch.device"] = None,
  68. seq_len: int | None = None,
  69. ) -> tuple["torch.Tensor", float]:
  70. """
  71. Computes the inverse frequencies according to the original RoPE implementation
  72. Args:
  73. config ([`~transformers.PreTrainedConfig`]):
  74. The model configuration.
  75. device (`torch.device`):
  76. The device to use for initialization of the inverse frequencies.
  77. seq_len (`int`, *optional*):
  78. The current sequence length. Unused for this type of RoPE.
  79. Returns:
  80. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  81. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  82. """
  83. base = config.rope_parameters["rope_theta"]
  84. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  85. attention_factor = 1.0 # Unused in this type of RoPE
  86. # Compute the inverse frequencies
  87. inv_freq = 1.0 / (
  88. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  89. )
  90. return inv_freq, attention_factor
  91. @torch.no_grad()
  92. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  93. def forward(self, x, position_ids):
  94. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  95. position_ids_expanded = position_ids[:, None, :].float()
  96. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  97. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  98. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  99. emb = torch.cat((freqs, freqs), dim=-1)
  100. cos = emb.cos() * self.attention_scaling
  101. sin = emb.sin() * self.attention_scaling
  102. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  103. @use_kernel_func_from_hub("rotary_pos_emb")
  104. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  105. """Applies Rotary Position Embedding to the query and key tensors.
  106. Args:
  107. q (`torch.Tensor`): The query tensor.
  108. k (`torch.Tensor`): The key tensor.
  109. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  110. sin (`torch.Tensor`): The sine part of the rotary embedding.
  111. unsqueeze_dim (`int`, *optional*, defaults to 1):
  112. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  113. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  114. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  115. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  116. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  117. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  118. Returns:
  119. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  120. """
  121. cos = cos.unsqueeze(unsqueeze_dim)
  122. sin = sin.unsqueeze(unsqueeze_dim)
  123. q_embed = (q * cos) + (rotate_half(q) * sin)
  124. k_embed = (k * cos) + (rotate_half(k) * sin)
  125. return q_embed, k_embed
  126. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  127. """
  128. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  129. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  130. """
  131. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  132. if n_rep == 1:
  133. return hidden_states
  134. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  135. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  136. def eager_attention_forward(
  137. module: nn.Module,
  138. query: torch.Tensor,
  139. key: torch.Tensor,
  140. value: torch.Tensor,
  141. attention_mask: torch.Tensor | None,
  142. scaling: float,
  143. dropout: float = 0.0,
  144. **kwargs: Unpack[TransformersKwargs],
  145. ):
  146. key_states = repeat_kv(key, module.num_key_value_groups)
  147. value_states = repeat_kv(value, module.num_key_value_groups)
  148. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  149. if attention_mask is not None:
  150. attn_weights = attn_weights + attention_mask
  151. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  152. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  153. attn_output = torch.matmul(attn_weights, value_states)
  154. attn_output = attn_output.transpose(1, 2).contiguous()
  155. return attn_output, attn_weights
  156. def rotate_half(x):
  157. """Rotates half the hidden dims of the input with flipped signs for NanoChat."""
  158. x1 = x[..., : x.shape[-1] // 2]
  159. x2 = x[..., x.shape[-1] // 2 :]
  160. return torch.cat((x2, -x1), dim=-1)
  161. @use_kernelized_func(apply_rotary_pos_emb)
  162. class NanoChatAttention(nn.Module):
  163. """Multi-headed attention from 'Attention Is All You Need' paper"""
  164. def __init__(self, config: NanoChatConfig, layer_idx: int):
  165. super().__init__()
  166. self.config = config
  167. self.layer_idx = layer_idx
  168. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  169. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  170. self.scaling = self.head_dim**-0.5
  171. self.attention_dropout = config.attention_dropout
  172. self.is_causal = True
  173. self.q_proj = nn.Linear(
  174. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  175. )
  176. self.k_proj = nn.Linear(
  177. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  178. )
  179. self.v_proj = nn.Linear(
  180. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  181. )
  182. self.o_proj = nn.Linear(
  183. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  184. )
  185. self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  186. self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  187. def forward(
  188. self,
  189. hidden_states: torch.Tensor,
  190. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  191. attention_mask: torch.Tensor | None = None,
  192. past_key_values: Cache | None = None,
  193. **kwargs: Unpack[TransformersKwargs],
  194. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  195. input_shape = hidden_states.shape[:-1]
  196. hidden_shape = (*input_shape, -1, self.head_dim)
  197. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  198. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  199. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  200. cos, sin = position_embeddings
  201. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  202. # RoPE -> Norm (instead of usual Norm -> RoPE)
  203. query_states = self.q_norm(query_states)
  204. key_states = self.k_norm(key_states)
  205. if past_key_values is not None:
  206. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  207. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  208. self.config._attn_implementation, eager_attention_forward
  209. )
  210. attn_output, attn_weights = attention_interface(
  211. self,
  212. query_states,
  213. key_states,
  214. value_states,
  215. attention_mask,
  216. dropout=0.0 if not self.training else self.attention_dropout,
  217. scaling=self.scaling,
  218. **kwargs,
  219. )
  220. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  221. attn_output = self.o_proj(attn_output)
  222. return attn_output, attn_weights
  223. class NanoChatMLP(nn.Module):
  224. def __init__(self, config):
  225. super().__init__()
  226. self.config = config
  227. self.activation_fn = ACT2FN[config.hidden_act]
  228. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
  229. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  230. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  231. hidden_states = self.fc1(hidden_states)
  232. hidden_states = self.activation_fn(hidden_states)
  233. hidden_states = self.fc2(hidden_states)
  234. return hidden_states
  235. class NanoChatDecoderLayer(GradientCheckpointingLayer):
  236. def __init__(self, config: NanoChatConfig, layer_idx: int):
  237. super().__init__()
  238. self.hidden_size = config.hidden_size
  239. self.self_attn = NanoChatAttention(config=config, layer_idx=layer_idx)
  240. self.mlp = NanoChatMLP(config)
  241. self.input_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  242. self.post_attention_layernorm = NanoChatRMSNorm(eps=config.rms_norm_eps)
  243. def forward(
  244. self,
  245. hidden_states: torch.Tensor,
  246. attention_mask: torch.Tensor | None = None,
  247. position_ids: torch.LongTensor | None = None,
  248. past_key_values: Cache | None = None,
  249. use_cache: bool | None = False,
  250. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  251. **kwargs: Unpack[TransformersKwargs],
  252. ) -> torch.Tensor:
  253. residual = hidden_states
  254. hidden_states = self.input_layernorm(hidden_states)
  255. # Self Attention
  256. hidden_states, _ = self.self_attn(
  257. hidden_states=hidden_states,
  258. attention_mask=attention_mask,
  259. position_ids=position_ids,
  260. past_key_values=past_key_values,
  261. use_cache=use_cache,
  262. position_embeddings=position_embeddings,
  263. **kwargs,
  264. )
  265. hidden_states = residual + hidden_states
  266. # Fully Connected
  267. residual = hidden_states
  268. hidden_states = self.post_attention_layernorm(hidden_states)
  269. hidden_states = self.mlp(hidden_states)
  270. hidden_states = residual + hidden_states
  271. return hidden_states
  272. @auto_docstring
  273. class NanoChatPreTrainedModel(PreTrainedModel):
  274. config: NanoChatConfig
  275. base_model_prefix = "model"
  276. supports_gradient_checkpointing = True
  277. _no_split_modules = ["NanoChatDecoderLayer"]
  278. _skip_keys_device_placement = ["past_key_values"]
  279. _supports_flash_attn = True
  280. _supports_sdpa = True
  281. _supports_flex_attn = True
  282. _can_compile_fullgraph = True
  283. _supports_attention_backend = True
  284. _can_record_outputs = {
  285. "hidden_states": NanoChatDecoderLayer,
  286. "attentions": NanoChatAttention,
  287. }
  288. def _init_weights(self, module: nn.Module) -> None:
  289. super()._init_weights(module)
  290. if isinstance(module, NanoChatAttention):
  291. init.normal_(
  292. module.o_proj.weight,
  293. mean=0.0,
  294. std=self.config.initializer_range / math.sqrt(2 * self.config.num_hidden_layers),
  295. )
@auto_docstring
class NanoChatModel(NanoChatPreTrainedModel):
    # Bare NanoChat decoder stack: token embedding -> RMSNorm -> decoder layers -> RMSNorm.
    # Returns hidden states only; no language-modeling head is applied here.
    def __init__(self, config: NanoChatConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [NanoChatDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # The same norm module is applied both before the first layer and after the last one (see `forward`).
        self.norm = NanoChatRMSNorm(eps=config.rms_norm_eps)
        # A single rotary-embedding module; its (cos, sin) output is passed to every layer.
        self.rotary_emb = NanoChatRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # XOR check: raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        # Create a fresh cache only when caching is requested and the caller did not pass one.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # Default positions continue from the number of tokens already in the cache; shape (1, seq_len).
        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # Computed once from the un-normalized embeddings (used for device/dtype) and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
        hidden_states = self.norm(hidden_states)  # Additional norm before the layers

        # NOTE: `use_cache` is not forwarded to the layers; the cache object itself is.
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                **kwargs,
            )

        # Final norm, using the same module as the pre-layer norm above.
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
@auto_docstring
class NanoChatForCausalLM(NanoChatPreTrainedModel, GenerationMixin):
    # NanoChat decoder with a linear language-modeling head and optional tanh soft-capping of the logits.
    # `lm_head.weight` is declared as tied to the input embedding weight.
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
    _tp_plan = {"lm_head": "colwise_gather_output"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = NanoChatModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM

        >>> model = AutoModelForCausalLM.from_pretrained("karpathy/nanochat-d32")
        >>> tokenizer = AutoTokenizer.from_pretrained("karpathy/nanochat-d32")

        >>> conversation = [
        ...     {"role": "user", "content": "What is the capital of France?"},
        ... ]
        >>> inputs = tokenizer.apply_chat_template(
        ...     conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ... ).to(model.device)

        >>> with torch.no_grad():
        ...     outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False)

        >>> generated_tokens = outputs[0, inputs["input_ids"].shape[1] :]
        >>> output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        ```"""
        # Run the base decoder; only `last_hidden_state`, `past_key_values`, `hidden_states` and
        # `attentions` of its output are read below.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 -> `slice(0, None)`, i.e. all of them);
        # any other value is used directly as the index along the sequence axis.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        # Optional soft-capping: cap * tanh(logits / cap) bounds the logits to (-cap, cap).
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            # The loss is computed on the (possibly sliced and soft-capped) logits.
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["NanoChatPreTrainedModel", "NanoChatModel", "NanoChatForCausalLM"]