modeling_granite.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/granite/modular_granite.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_granite.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  29. from ...masking_utils import create_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  37. from ...utils.output_capturing import capture_outputs
  38. from .configuration_granite import GraniteConfig
  39. def rotate_half(x):
  40. """Rotates half the hidden dims of the input."""
  41. x1 = x[..., : x.shape[-1] // 2]
  42. x2 = x[..., x.shape[-1] // 2 :]
  43. return torch.cat((-x2, x1), dim=-1)
  44. @use_kernel_func_from_hub("rotary_pos_emb")
  45. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  46. """Applies Rotary Position Embedding to the query and key tensors.
  47. Args:
  48. q (`torch.Tensor`): The query tensor.
  49. k (`torch.Tensor`): The key tensor.
  50. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  51. sin (`torch.Tensor`): The sine part of the rotary embedding.
  52. unsqueeze_dim (`int`, *optional*, defaults to 1):
  53. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  54. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  55. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  56. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  57. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  58. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  59. Returns:
  60. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  61. """
  62. cos = cos.unsqueeze(unsqueeze_dim)
  63. sin = sin.unsqueeze(unsqueeze_dim)
  64. q_embed = (q * cos) + (rotate_half(q) * sin)
  65. k_embed = (k * cos) + (rotate_half(k) * sin)
  66. return q_embed, k_embed
  67. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  68. """
  69. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  70. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  71. """
  72. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  73. if n_rep == 1:
  74. return hidden_states
  75. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  76. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  77. def eager_attention_forward(
  78. module: nn.Module,
  79. query: torch.Tensor,
  80. key: torch.Tensor,
  81. value: torch.Tensor,
  82. attention_mask: torch.Tensor | None,
  83. scaling: float,
  84. dropout: float = 0.0,
  85. **kwargs: Unpack[TransformersKwargs],
  86. ):
  87. key_states = repeat_kv(key, module.num_key_value_groups)
  88. value_states = repeat_kv(value, module.num_key_value_groups)
  89. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  90. if attention_mask is not None:
  91. attn_weights = attn_weights + attention_mask
  92. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  93. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  94. attn_output = torch.matmul(attn_weights, value_states)
  95. attn_output = attn_output.transpose(1, 2).contiguous()
  96. return attn_output, attn_weights
  97. @use_kernelized_func(apply_rotary_pos_emb)
  98. class GraniteAttention(nn.Module):
  99. """Multi-headed attention from 'Attention Is All You Need' paper"""
  100. def __init__(self, config: GraniteConfig, layer_idx: int | None = None):
  101. super().__init__()
  102. self.config = config
  103. self.layer_idx = layer_idx
  104. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  105. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  106. self.scaling = config.attention_multiplier
  107. self.attention_dropout = config.attention_dropout
  108. self.is_causal = True
  109. self.q_proj = nn.Linear(
  110. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  111. )
  112. self.k_proj = nn.Linear(
  113. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  114. )
  115. self.v_proj = nn.Linear(
  116. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  117. )
  118. self.o_proj = nn.Linear(
  119. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  120. )
  121. def forward(
  122. self,
  123. hidden_states: torch.Tensor,
  124. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  125. attention_mask: torch.Tensor | None = None,
  126. past_key_values: Cache | None = None,
  127. **kwargs: Unpack[TransformersKwargs],
  128. ) -> tuple[torch.Tensor, torch.Tensor]:
  129. input_shape = hidden_states.shape[:-1]
  130. hidden_shape = (*input_shape, -1, self.head_dim)
  131. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  132. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  133. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  134. cos, sin = position_embeddings
  135. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  136. if past_key_values is not None:
  137. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  138. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  139. self.config._attn_implementation, eager_attention_forward
  140. )
  141. attn_output, attn_weights = attention_interface(
  142. self,
  143. query_states,
  144. key_states,
  145. value_states,
  146. attention_mask,
  147. dropout=0.0 if not self.training else self.attention_dropout,
  148. scaling=self.scaling,
  149. **kwargs,
  150. )
  151. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  152. attn_output = self.o_proj(attn_output)
  153. return attn_output, attn_weights
  154. @use_kernel_forward_from_hub("RMSNorm")
  155. class GraniteRMSNorm(nn.Module):
  156. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  157. """
  158. GraniteRMSNorm is equivalent to T5LayerNorm
  159. """
  160. super().__init__()
  161. self.weight = nn.Parameter(torch.ones(hidden_size))
  162. self.variance_epsilon = eps
  163. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  164. input_dtype = hidden_states.dtype
  165. hidden_states = hidden_states.to(torch.float32)
  166. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  167. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  168. return self.weight * hidden_states.to(input_dtype)
  169. def extra_repr(self):
  170. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  171. class GraniteMLP(nn.Module):
  172. def __init__(self, config):
  173. super().__init__()
  174. self.config = config
  175. self.hidden_size = config.hidden_size
  176. self.intermediate_size = config.intermediate_size
  177. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  178. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  179. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  180. self.act_fn = ACT2FN[config.hidden_act]
  181. def forward(self, x):
  182. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  183. return down_proj
  184. class GraniteDecoderLayer(GradientCheckpointingLayer):
  185. def __init__(self, config: GraniteConfig, layer_idx: int):
  186. super().__init__()
  187. self.hidden_size = config.hidden_size
  188. self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx)
  189. self.mlp = GraniteMLP(config)
  190. self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  191. self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  192. self.residual_multiplier = config.residual_multiplier
  193. def forward(
  194. self,
  195. hidden_states: torch.Tensor,
  196. attention_mask: torch.Tensor | None = None,
  197. position_ids: torch.LongTensor | None = None,
  198. past_key_values: Cache | None = None,
  199. use_cache: bool | None = False,
  200. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  201. **kwargs: Unpack[TransformersKwargs],
  202. ) -> torch.Tensor:
  203. """
  204. Args:
  205. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  206. attention_mask (`torch.FloatTensor`, *optional*):
  207. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  208. query_sequence_length, key_sequence_length)` if default attention is used.
  209. output_attentions (`bool`, *optional*):
  210. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  211. returned tensors for more detail.
  212. use_cache (`bool`, *optional*):
  213. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  214. (see `past_key_values`).
  215. past_key_values (`Cache`, *optional*): cached past key and value projection states
  216. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  217. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  218. with `head_dim` being the embedding dimension of each attention head.
  219. kwargs (`dict`, *optional*):
  220. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  221. into the model
  222. """
  223. residual = hidden_states
  224. hidden_states = self.input_layernorm(hidden_states)
  225. hidden_states, _ = self.self_attn(
  226. hidden_states=hidden_states,
  227. attention_mask=attention_mask,
  228. position_ids=position_ids,
  229. past_key_values=past_key_values,
  230. use_cache=use_cache,
  231. position_embeddings=position_embeddings,
  232. **kwargs,
  233. )
  234. hidden_states = residual + hidden_states * self.residual_multiplier
  235. residual = hidden_states
  236. hidden_states = self.post_attention_layernorm(hidden_states)
  237. hidden_states = self.mlp(hidden_states)
  238. hidden_states = residual + hidden_states * self.residual_multiplier
  239. return hidden_states
@auto_docstring
class GranitePreTrainedModel(PreTrainedModel):
    # Common base for the Granite model classes in this file. It defines no methods of its own; it only
    # declares class-level settings that are consumed by `PreTrainedModel` (defined outside this file).
    # NOTE(review): the exact meaning of each flag below is defined by `PreTrainedModel` — the inline
    # notes are read off the names and the classes in this file; confirm against the base class.
    config: GraniteConfig
    # Matches the attribute name under which GraniteForCausalLM stores its GraniteModel (`self.model`).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Names the decoder-layer class defined in this file.
    _no_split_modules = ["GraniteDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention-backend / compilation capability flags (presumably gate which `_attn_implementation`
    # values are accepted — TODO confirm).
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output field name to the module class whose outputs populate it (presumably used by
    # `capture_outputs` on GraniteModel.forward — TODO confirm).
    _can_record_outputs = {
        "hidden_states": GraniteDecoderLayer,
        "attentions": GraniteAttention,
    }
  256. class GraniteRotaryEmbedding(nn.Module):
  257. inv_freq: torch.Tensor # fix linting for `register_buffer`
  258. def __init__(self, config: GraniteConfig, device=None):
  259. super().__init__()
  260. self.max_seq_len_cached = config.max_position_embeddings
  261. self.original_max_seq_len = config.max_position_embeddings
  262. self.config = config
  263. self.rope_type = self.config.rope_parameters["rope_type"]
  264. rope_init_fn: Callable = self.compute_default_rope_parameters
  265. if self.rope_type != "default":
  266. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  267. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  268. self.register_buffer("inv_freq", inv_freq, persistent=False)
  269. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  270. @staticmethod
  271. def compute_default_rope_parameters(
  272. config: GraniteConfig | None = None,
  273. device: Optional["torch.device"] = None,
  274. seq_len: int | None = None,
  275. ) -> tuple["torch.Tensor", float]:
  276. """
  277. Computes the inverse frequencies according to the original RoPE implementation
  278. Args:
  279. config ([`~transformers.PreTrainedConfig`]):
  280. The model configuration.
  281. device (`torch.device`):
  282. The device to use for initialization of the inverse frequencies.
  283. seq_len (`int`, *optional*):
  284. The current sequence length. Unused for this type of RoPE.
  285. Returns:
  286. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  287. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  288. """
  289. base = config.rope_parameters["rope_theta"]
  290. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  291. attention_factor = 1.0 # Unused in this type of RoPE
  292. # Compute the inverse frequencies
  293. inv_freq = 1.0 / (
  294. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  295. )
  296. return inv_freq, attention_factor
  297. @torch.no_grad()
  298. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  299. def forward(self, x, position_ids):
  300. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  301. position_ids_expanded = position_ids[:, None, :].float()
  302. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  303. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  304. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  305. emb = torch.cat((freqs, freqs), dim=-1)
  306. cos = emb.cos() * self.attention_scaling
  307. sin = emb.sin() * self.attention_scaling
  308. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  309. @auto_docstring
  310. class GraniteModel(GranitePreTrainedModel):
  311. def __init__(self, config: GraniteConfig):
  312. super().__init__(config)
  313. self.padding_idx = config.pad_token_id
  314. self.vocab_size = config.vocab_size
  315. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  316. self.layers = nn.ModuleList(
  317. [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  318. )
  319. self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  320. self.rotary_emb = GraniteRotaryEmbedding(config=config)
  321. self.gradient_checkpointing = False
  322. self.embedding_multiplier = config.embedding_multiplier
  323. # Initialize weights and apply final processing
  324. self.post_init()
  325. @merge_with_config_defaults
  326. @capture_outputs
  327. @auto_docstring
  328. def forward(
  329. self,
  330. input_ids: torch.LongTensor | None = None,
  331. attention_mask: torch.Tensor | None = None,
  332. position_ids: torch.LongTensor | None = None,
  333. past_key_values: Cache | None = None,
  334. inputs_embeds: torch.FloatTensor | None = None,
  335. use_cache: bool | None = None,
  336. **kwargs: Unpack[TransformersKwargs],
  337. ) -> BaseModelOutputWithPast:
  338. if (input_ids is None) ^ (inputs_embeds is not None):
  339. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  340. if inputs_embeds is None:
  341. inputs_embeds = self.embed_tokens(input_ids)
  342. inputs_embeds = inputs_embeds * self.embedding_multiplier
  343. if use_cache and past_key_values is None:
  344. past_key_values = DynamicCache(config=self.config)
  345. if position_ids is None:
  346. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  347. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  348. position_ids = position_ids.unsqueeze(0)
  349. causal_mask = create_causal_mask(
  350. config=self.config,
  351. inputs_embeds=inputs_embeds,
  352. attention_mask=attention_mask,
  353. past_key_values=past_key_values,
  354. position_ids=position_ids,
  355. )
  356. hidden_states = inputs_embeds
  357. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  358. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  359. hidden_states = decoder_layer(
  360. hidden_states,
  361. attention_mask=causal_mask,
  362. position_ids=position_ids,
  363. past_key_values=past_key_values,
  364. use_cache=use_cache,
  365. position_embeddings=position_embeddings,
  366. **kwargs,
  367. )
  368. hidden_states = self.norm(hidden_states)
  369. return BaseModelOutputWithPast(
  370. last_hidden_state=hidden_states,
  371. past_key_values=past_key_values,
  372. )
  373. @auto_docstring
  374. class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
  375. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  376. _tp_plan = {"lm_head": "colwise_gather_output"}
  377. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  378. def __init__(self, config):
  379. super().__init__(config)
  380. self.model = GraniteModel(config)
  381. self.vocab_size = config.vocab_size
  382. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  383. # Initialize weights and apply final processing
  384. self.post_init()
  385. @can_return_tuple
  386. @auto_docstring
  387. def forward(
  388. self,
  389. input_ids: torch.LongTensor | None = None,
  390. attention_mask: torch.Tensor | None = None,
  391. position_ids: torch.LongTensor | None = None,
  392. past_key_values: Cache | None = None,
  393. inputs_embeds: torch.FloatTensor | None = None,
  394. labels: torch.LongTensor | None = None,
  395. use_cache: bool | None = None,
  396. logits_to_keep: int | torch.Tensor = 0,
  397. **kwargs: Unpack[TransformersKwargs],
  398. ) -> CausalLMOutputWithPast:
  399. r"""
  400. Example:
  401. ```python
  402. >>> from transformers import AutoTokenizer, GraniteForCausalLM
  403. >>> model = GraniteForCausalLM.from_pretrained("meta-granite/Granite-2-7b-hf")
  404. >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite/Granite-2-7b-hf")
  405. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  406. >>> inputs = tokenizer(prompt, return_tensors="pt")
  407. >>> # Generate
  408. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  409. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  410. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  411. ```"""
  412. outputs: BaseModelOutputWithPast = self.model(
  413. input_ids=input_ids,
  414. attention_mask=attention_mask,
  415. position_ids=position_ids,
  416. past_key_values=past_key_values,
  417. inputs_embeds=inputs_embeds,
  418. use_cache=use_cache,
  419. **kwargs,
  420. )
  421. hidden_states = outputs.last_hidden_state
  422. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  423. logits = self.lm_head(hidden_states[:, slice_indices, :])
  424. logits = logits / self.config.logits_scaling # main diff with Llama
  425. loss = None
  426. if labels is not None:
  427. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  428. return CausalLMOutputWithPast(
  429. loss=loss,
  430. logits=logits,
  431. past_key_values=outputs.past_key_values,
  432. hidden_states=outputs.hidden_states,
  433. attentions=outputs.attentions,
  434. )
# Names exported by `from <this module> import *`; the helper functions and building-block layers
# defined above are intentionally not listed.
__all__ = ["GraniteForCausalLM", "GraniteModel", "GranitePreTrainedModel"]