modeling_nemotron.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731
  1. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  2. # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Nemotron model."""
  16. import math
  17. from collections.abc import Callable
  18. from typing import Optional
  19. import torch
  20. import torch.nn.functional as F
  21. from torch import Size, Tensor, nn
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, StaticCache
  25. from ...generation import GenerationMixin
  26. from ...masking_utils import create_causal_mask
  27. from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  28. from ...modeling_layers import (
  29. GenericForQuestionAnswering,
  30. GenericForSequenceClassification,
  31. GenericForTokenClassification,
  32. GradientCheckpointingLayer,
  33. )
  34. from ...modeling_outputs import (
  35. BaseModelOutputWithPast,
  36. CausalLMOutputWithPast,
  37. )
  38. from ...modeling_rope_utils import (
  39. ROPE_INIT_FUNCTIONS,
  40. dynamic_rope_update,
  41. )
  42. from ...modeling_utils import PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  45. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  46. from ...utils.output_capturing import capture_outputs
  47. from .configuration_nemotron import NemotronConfig
  48. logger = logging.get_logger(__name__)
  49. def _cast_if_autocast_enabled(device_type, *args):
  50. if not torch.is_autocast_enabled():
  51. return args
  52. else:
  53. target_dtype = torch.get_autocast_dtype(device_type)
  54. return torch.amp.autocast_mode._cast(args, device_type, target_dtype)
  55. class NemotronLayerNorm1P(nn.LayerNorm):
  56. def __init__(
  57. self,
  58. normalized_shape: int | list[int] | Size,
  59. eps: float = 1e-5,
  60. elementwise_affine: bool = True,
  61. bias: bool = True,
  62. device=None,
  63. dtype=None,
  64. ):
  65. super().__init__(normalized_shape, eps, elementwise_affine, bias, device, dtype)
  66. def forward(self, input: Tensor) -> Tensor:
  67. device_type = input.device.type if input.device.type != "mps" else "cpu"
  68. args = _cast_if_autocast_enabled(
  69. device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps
  70. )
  71. with maybe_autocast(device_type=input.device.type, enabled=False):
  72. return F.layer_norm(*args)
  73. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  74. class NemotronRotaryEmbedding(nn.Module):
  75. inv_freq: torch.Tensor # fix linting for `register_buffer`
  76. def __init__(self, config: NemotronConfig, device=None):
  77. super().__init__()
  78. self.max_seq_len_cached = config.max_position_embeddings
  79. self.original_max_seq_len = config.max_position_embeddings
  80. self.config = config
  81. self.rope_type = self.config.rope_parameters["rope_type"]
  82. rope_init_fn: Callable = self.compute_default_rope_parameters
  83. if self.rope_type != "default":
  84. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  85. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  86. self.register_buffer("inv_freq", inv_freq, persistent=False)
  87. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  88. @staticmethod
  89. # Ignore copy
  90. def compute_default_rope_parameters(
  91. config: NemotronConfig | None = None,
  92. device: Optional["torch.device"] = None,
  93. seq_len: int | None = None,
  94. ) -> tuple["torch.Tensor", float]:
  95. """
  96. Computes the inverse frequencies according to the original RoPE implementation
  97. Args:
  98. config ([`~transformers.PreTrainedConfig`]):
  99. The model configuration.
  100. device (`torch.device`):
  101. The device to use for initialization of the inverse frequencies.
  102. seq_len (`int`, *optional*):
  103. The current sequence length. Unused for this type of RoPE.
  104. Returns:
  105. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  106. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  107. """
  108. base = config.rope_parameters["rope_theta"]
  109. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  110. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  111. dim = int(head_dim * partial_rotary_factor)
  112. attention_factor = 1.0 # Unused in this type of RoPE
  113. # Compute the inverse frequencies
  114. inv_freq = 1.0 / (
  115. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  116. )
  117. return inv_freq, attention_factor
  118. @torch.no_grad()
  119. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  120. def forward(self, x, position_ids):
  121. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  122. position_ids_expanded = position_ids[:, None, :].float()
  123. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  124. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  125. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  126. emb = torch.cat((freqs, freqs), dim=-1)
  127. cos = emb.cos() * self.attention_scaling
  128. sin = emb.sin() * self.attention_scaling
  129. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  130. # Copied from transformers.models.llama.modeling_llama.rotate_half
  131. def rotate_half(x):
  132. """Rotates half the hidden dims of the input."""
  133. x1 = x[..., : x.shape[-1] // 2]
  134. x2 = x[..., x.shape[-1] // 2 :]
  135. return torch.cat((-x2, x1), dim=-1)
  136. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  137. """Applies Rotary Position Embedding to the query and key tensors.
  138. Args:
  139. q (`torch.Tensor`): The query tensor.
  140. k (`torch.Tensor`): The key tensor.
  141. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  142. sin (`torch.Tensor`): The sine part of the rotary embedding.
  143. unsqueeze_dim (`int`, *optional*, defaults to 1):
  144. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  145. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  146. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  147. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  148. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  149. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  150. Returns:
  151. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  152. """
  153. cos = cos.unsqueeze(unsqueeze_dim)
  154. sin = sin.unsqueeze(unsqueeze_dim)
  155. rot_dim = cos.shape[-1]
  156. # If q_pass/k_pass is empty, rotary pos embedding is applied to all tensor q/k
  157. q, q_pass = q[..., :rot_dim], q[..., rot_dim:]
  158. k, k_pass = k[..., :rot_dim], k[..., rot_dim:]
  159. q_embed = (q * cos) + (rotate_half(q) * sin)
  160. k_embed = (k * cos) + (rotate_half(k) * sin)
  161. return torch.cat((q_embed, q_pass), dim=-1), torch.cat((k_embed, k_pass), dim=-1)
  162. class NemotronMLP(nn.Module):
  163. def __init__(self, config):
  164. super().__init__()
  165. self.config = config
  166. self.hidden_size = config.hidden_size
  167. self.intermediate_size = config.intermediate_size
  168. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  169. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  170. self.act_fn = ACT2FN[config.hidden_act]
  171. def forward(self, x):
  172. return self.down_proj(self.act_fn(self.up_proj(x)))
  173. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  174. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  175. """
  176. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  177. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  178. """
  179. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  180. if n_rep == 1:
  181. return hidden_states
  182. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  183. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  184. class NemotronAttention(nn.Module):
  185. """Multi-headed attention from 'Attention Is All You Need' paper"""
  186. def __init__(self, config: NemotronConfig, layer_idx: int | None = None):
  187. super().__init__()
  188. self.config = config
  189. self.layer_idx = layer_idx
  190. if layer_idx is None:
  191. logger.warning_once(
  192. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  193. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  194. "when creating this class."
  195. )
  196. self.attention_dropout = config.attention_dropout
  197. self.hidden_size = config.hidden_size
  198. self.num_heads = config.num_attention_heads
  199. self.head_dim = config.head_dim
  200. self.num_key_value_heads = config.num_key_value_heads
  201. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  202. self.max_position_embeddings = config.max_position_embeddings
  203. self.partial_rotary_factor = config.rope_parameters["partial_rotary_factor"]
  204. self.is_causal = True
  205. self.rotary_emb = NemotronRotaryEmbedding(config=config)
  206. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  207. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  208. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  209. self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias)
  210. def forward(
  211. self,
  212. hidden_states: torch.Tensor,
  213. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  214. attention_mask: torch.Tensor | None = None,
  215. position_ids: torch.LongTensor | None = None,
  216. past_key_values: Cache | None = None,
  217. output_attentions: bool = False,
  218. use_cache: bool = False,
  219. **kwargs,
  220. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  221. bsz, q_len, _ = hidden_states.size()
  222. query_states = self.q_proj(hidden_states)
  223. key_states = self.k_proj(hidden_states)
  224. value_states = self.v_proj(hidden_states)
  225. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  226. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  227. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  228. cos, sin = position_embeddings
  229. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  230. if past_key_values is not None:
  231. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  232. key_states = repeat_kv(key_states, self.num_key_value_groups)
  233. value_states = repeat_kv(value_states, self.num_key_value_groups)
  234. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  235. if attention_mask is not None:
  236. attn_weights = attn_weights + attention_mask
  237. # upcast attention to fp32
  238. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  239. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  240. attn_output = torch.matmul(attn_weights, value_states)
  241. attn_output = attn_output.transpose(1, 2).contiguous()
  242. attn_output = attn_output.reshape(bsz, q_len, -1)
  243. attn_output = self.o_proj(attn_output)
  244. return attn_output, attn_weights
  245. # NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  246. # TODO cyril: modular
  247. class NemotronFlashAttention2(NemotronAttention):
  248. """
  249. Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays
  250. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  251. flash attention and deal with padding tokens in case the input contains any of them.
  252. """
  253. def __init__(self, *args, **kwargs):
  254. super().__init__(*args, **kwargs)
  255. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  256. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  257. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  258. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  259. def forward(
  260. self,
  261. hidden_states: torch.Tensor,
  262. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  263. attention_mask: torch.LongTensor | None = None,
  264. position_ids: torch.LongTensor | None = None,
  265. past_key_values: Cache | None = None,
  266. output_attentions: bool = False,
  267. use_cache: bool = False,
  268. **kwargs,
  269. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  270. if isinstance(past_key_values, StaticCache):
  271. raise ValueError(
  272. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  273. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  274. )
  275. bsz, q_len, _ = hidden_states.size()
  276. query_states = self.q_proj(hidden_states)
  277. key_states = self.k_proj(hidden_states)
  278. value_states = self.v_proj(hidden_states)
  279. # Flash attention requires the input to have the shape
  280. # batch_size x seq_length x head_dim x hidden_dim
  281. # therefore we just need to keep the original shape
  282. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  283. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  284. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  285. cos, sin = position_embeddings
  286. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  287. if past_key_values is not None:
  288. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  289. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  290. # to be able to avoid many of these transpose/reshape/view.
  291. query_states = query_states.transpose(1, 2)
  292. key_states = key_states.transpose(1, 2)
  293. value_states = value_states.transpose(1, 2)
  294. dropout_rate = self.attention_dropout if self.training else 0.0
  295. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  296. # therefore the input hidden states gets silently casted in float32. Hence, we need
  297. # cast them back in the correct dtype just to be sure everything works as expected.
  298. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  299. # in fp32. (NemotronRMSNorm handles it correctly)
  300. input_dtype = query_states.dtype
  301. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  302. if input_dtype == torch.float32:
  303. if torch.is_autocast_enabled():
  304. target_dtype = torch.get_autocast_dtype(device_type)
  305. # Handle the case where the model is quantized
  306. elif hasattr(self.config, "_is_quantized"):
  307. target_dtype = self.config.dtype
  308. else:
  309. target_dtype = self.q_proj.weight.dtype
  310. logger.warning_once(
  311. f"The input hidden states seems to be silently casted in float32, this might be related to"
  312. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  313. f" {target_dtype}."
  314. )
  315. query_states = query_states.to(target_dtype)
  316. key_states = key_states.to(target_dtype)
  317. value_states = value_states.to(target_dtype)
  318. attn_output = _flash_attention_forward(
  319. query_states,
  320. key_states,
  321. value_states,
  322. attention_mask,
  323. q_len,
  324. position_ids=position_ids,
  325. dropout=dropout_rate,
  326. sliding_window=getattr(self, "sliding_window", None),
  327. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  328. is_causal=self.is_causal,
  329. )
  330. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  331. attn_output = self.o_proj(attn_output)
  332. return attn_output, None
  333. # NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  334. # TODO cyril: modular
  335. class NemotronSdpaAttention(NemotronAttention):
  336. """
  337. Nemotron attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  338. `NemotronAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  339. SDPA API.
  340. """
  341. def forward(
  342. self,
  343. hidden_states: torch.Tensor,
  344. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  345. attention_mask: torch.Tensor | None = None,
  346. position_ids: torch.LongTensor | None = None,
  347. past_key_values: Cache | None = None,
  348. output_attentions: bool = False,
  349. use_cache: bool = False,
  350. **kwargs,
  351. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  352. if output_attentions:
  353. logger.warning_once(
  354. f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
  355. "be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
  356. )
  357. bsz, q_len, _ = hidden_states.size()
  358. query_states = self.q_proj(hidden_states)
  359. key_states = self.k_proj(hidden_states)
  360. value_states = self.v_proj(hidden_states)
  361. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  362. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  363. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  364. cos, sin = position_embeddings
  365. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  366. if past_key_values is not None:
  367. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  368. key_states = repeat_kv(key_states, self.num_key_value_groups)
  369. value_states = repeat_kv(value_states, self.num_key_value_groups)
  370. causal_mask = attention_mask
  371. if attention_mask is not None:
  372. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  373. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  374. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  375. is_causal = causal_mask is None and q_len > 1
  376. attn_output = torch.nn.functional.scaled_dot_product_attention(
  377. query_states,
  378. key_states,
  379. value_states,
  380. attn_mask=causal_mask,
  381. dropout_p=self.attention_dropout if self.training else 0.0,
  382. is_causal=is_causal,
  383. )
  384. attn_output = attn_output.transpose(1, 2).contiguous()
  385. attn_output = attn_output.view(bsz, q_len, -1)
  386. attn_output = self.o_proj(attn_output)
  387. return attn_output, None
  388. NEMOTRON_ATTENTION_CLASSES = {
  389. "eager": NemotronAttention,
  390. "flash_attention_2": NemotronFlashAttention2,
  391. "sdpa": NemotronSdpaAttention,
  392. }
  393. # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  394. # no longer copied after attention refactors
  395. class NemotronDecoderLayer(GradientCheckpointingLayer):
  396. # Ignore copy
  397. def __init__(self, config: NemotronConfig, layer_idx: int):
  398. super().__init__()
  399. self.hidden_size = config.hidden_size
  400. self.self_attn = NEMOTRON_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  401. self.mlp = NemotronMLP(config)
  402. self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
  403. self.post_attention_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
  404. def forward(
  405. self,
  406. hidden_states: torch.Tensor,
  407. attention_mask: torch.Tensor | None = None,
  408. position_ids: torch.LongTensor | None = None,
  409. past_key_values: Cache | None = None,
  410. use_cache: bool | None = False,
  411. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  412. **kwargs,
  413. ) -> torch.Tensor:
  414. residual = hidden_states
  415. hidden_states = self.input_layernorm(hidden_states)
  416. # Self Attention
  417. hidden_states, _ = self.self_attn(
  418. hidden_states=hidden_states,
  419. attention_mask=attention_mask,
  420. position_ids=position_ids,
  421. past_key_values=past_key_values,
  422. use_cache=use_cache,
  423. position_embeddings=position_embeddings,
  424. )
  425. hidden_states = residual + hidden_states
  426. # Fully Connected
  427. residual = hidden_states
  428. hidden_states = self.post_attention_layernorm(hidden_states)
  429. hidden_states = self.mlp(hidden_states)
  430. hidden_states = residual + hidden_states
  431. return hidden_states
  432. @auto_docstring
  433. class NemotronPreTrainedModel(PreTrainedModel):
  434. config: NemotronConfig
  435. base_model_prefix = "model"
  436. supports_gradient_checkpointing = True
  437. _no_split_modules = ["NemotronDecoderLayer"]
  438. _skip_keys_device_placement = ["past_key_values"]
  439. _supports_flash_attn = True
  440. _supports_sdpa = True
  441. _can_compile_fullgraph = True
  442. _can_record_outputs = {
  443. "hidden_states": NemotronDecoderLayer,
  444. "attentions": NemotronAttention,
  445. }
  446. @torch.no_grad()
  447. def _init_weights(self, module):
  448. super()._init_weights(module)
  449. if isinstance(module, NemotronLayerNorm1P):
  450. init.ones_(module.weight)
  451. init.zeros_(module.bias)
  452. @auto_docstring
  453. class NemotronModel(NemotronPreTrainedModel):
  454. """
  455. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`NemotronDecoderLayer`]
  456. Args:
  457. config: NemotronConfig
  458. """
  459. def __init__(self, config: NemotronConfig):
  460. super().__init__(config)
  461. self.padding_idx = config.pad_token_id
  462. self.vocab_size = config.vocab_size
  463. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  464. self.layers = nn.ModuleList(
  465. [NemotronDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  466. )
  467. self.norm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps)
  468. self.rotary_emb = NemotronRotaryEmbedding(config=config)
  469. self.gradient_checkpointing = False
  470. # Initialize weights and apply final processing
  471. self.post_init()
  472. @merge_with_config_defaults
  473. @capture_outputs
  474. @auto_docstring
  475. def forward(
  476. self,
  477. input_ids: torch.LongTensor | None = None,
  478. attention_mask: torch.Tensor | None = None,
  479. position_ids: torch.LongTensor | None = None,
  480. past_key_values: Cache | None = None,
  481. inputs_embeds: torch.FloatTensor | None = None,
  482. use_cache: bool | None = None,
  483. **kwargs: Unpack[TransformersKwargs],
  484. ) -> BaseModelOutputWithPast:
  485. if (input_ids is None) ^ (inputs_embeds is not None):
  486. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  487. if use_cache and past_key_values is None:
  488. past_key_values = DynamicCache(config=self.config)
  489. if inputs_embeds is None:
  490. inputs_embeds = self.embed_tokens(input_ids)
  491. if position_ids is None:
  492. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  493. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  494. position_ids = position_ids.unsqueeze(0)
  495. causal_mask = create_causal_mask(
  496. config=self.config,
  497. inputs_embeds=inputs_embeds,
  498. attention_mask=attention_mask,
  499. past_key_values=past_key_values,
  500. position_ids=position_ids,
  501. )
  502. hidden_states = inputs_embeds
  503. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  504. for decoder_layer in self.layers:
  505. hidden_states = decoder_layer(
  506. hidden_states,
  507. attention_mask=causal_mask,
  508. position_ids=position_ids,
  509. past_key_values=past_key_values,
  510. use_cache=use_cache,
  511. position_embeddings=position_embeddings,
  512. **kwargs,
  513. )
  514. hidden_states = self.norm(hidden_states)
  515. return BaseModelOutputWithPast(
  516. last_hidden_state=hidden_states,
  517. past_key_values=past_key_values,
  518. )
  519. # TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
  520. class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin):
  521. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  522. def __init__(self, config):
  523. super().__init__(config)
  524. self.model = NemotronModel(config)
  525. self.vocab_size = config.vocab_size
  526. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  527. # Initialize weights and apply final processing
  528. self.post_init()
  529. @can_return_tuple
  530. @auto_docstring
  531. def forward(
  532. self,
  533. input_ids: torch.LongTensor | None = None,
  534. attention_mask: torch.Tensor | None = None,
  535. position_ids: torch.LongTensor | None = None,
  536. past_key_values: Cache | None = None,
  537. inputs_embeds: torch.FloatTensor | None = None,
  538. labels: torch.LongTensor | None = None,
  539. use_cache: bool | None = None,
  540. logits_to_keep: int | torch.Tensor = 0,
  541. **kwargs: Unpack[TransformersKwargs],
  542. ) -> CausalLMOutputWithPast:
  543. r"""
  544. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  545. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  546. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  547. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  548. Example:
  549. ```python
  550. >>> from transformers import AutoTokenizer, NemotronForCausalLM
  551. >>> model = NemotronForCausalLM.from_pretrained("thhaus/nemotron3-8b")
  552. >>> tokenizer = AutoTokenizer.from_pretrained("thhaus/nemotron3-8b")
  553. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  554. >>> inputs = tokenizer(prompt, return_tensors="pt")
  555. >>> # Generate
  556. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  557. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  558. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  559. ```"""
  560. outputs: BaseModelOutputWithPast = self.model(
  561. input_ids=input_ids,
  562. attention_mask=attention_mask,
  563. position_ids=position_ids,
  564. past_key_values=past_key_values,
  565. inputs_embeds=inputs_embeds,
  566. use_cache=use_cache,
  567. **kwargs,
  568. )
  569. hidden_states = outputs.last_hidden_state
  570. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  571. logits = self.lm_head(hidden_states[:, slice_indices, :])
  572. loss = None
  573. if labels is not None:
  574. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  575. return CausalLMOutputWithPast(
  576. loss=loss,
  577. logits=logits,
  578. past_key_values=outputs.past_key_values,
  579. hidden_states=outputs.hidden_states,
  580. attentions=outputs.attentions,
  581. )
  582. class NemotronForSequenceClassification(GenericForSequenceClassification, NemotronPreTrainedModel): ...
  583. class NemotronForQuestionAnswering(GenericForQuestionAnswering, NemotronPreTrainedModel):
  584. base_model_prefix = "transformer"
  585. class NemotronForTokenClassification(GenericForTokenClassification, NemotronPreTrainedModel): ...
  586. __all__ = [
  587. "NemotronForQuestionAnswering",
  588. "NemotronForCausalLM",
  589. "NemotronModel",
  590. "NemotronPreTrainedModel",
  591. "NemotronForSequenceClassification",
  592. "NemotronForTokenClassification",
  593. ]