modeling_lfm2.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/lfm2/modular_lfm2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_lfm2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. import torch.nn.functional as F
  24. from torch import nn
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  28. from ...masking_utils import create_causal_mask
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  31. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  35. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  36. from ...utils.import_utils import is_causal_conv1d_available, is_torchdynamo_compiling
  37. from ...utils.output_capturing import capture_outputs
  38. from .configuration_lfm2 import Lfm2Config
# Optional fused kernels used by Lfm2ShortConv.cuda_kernels_forward. When the package is not available both
# names are set to None, and Lfm2ShortConv.forward then takes the pure-PyTorch `slow_forward` path.
if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_fn, causal_conv1d_update = None, None
  43. @use_kernel_forward_from_hub("RMSNorm")
  44. class Lfm2RMSNorm(nn.Module):
  45. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  46. """
  47. Lfm2RMSNorm is equivalent to T5LayerNorm
  48. """
  49. super().__init__()
  50. self.weight = nn.Parameter(torch.ones(hidden_size))
  51. self.variance_epsilon = eps
  52. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  53. input_dtype = hidden_states.dtype
  54. hidden_states = hidden_states.to(torch.float32)
  55. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  56. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  57. return self.weight * hidden_states.to(input_dtype)
  58. def extra_repr(self):
  59. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  60. class Lfm2RotaryEmbedding(nn.Module):
  61. inv_freq: torch.Tensor # fix linting for `register_buffer`
  62. def __init__(self, config: Lfm2Config, device=None):
  63. super().__init__()
  64. self.max_seq_len_cached = config.max_position_embeddings
  65. self.original_max_seq_len = config.max_position_embeddings
  66. self.config = config
  67. self.rope_type = self.config.rope_parameters["rope_type"]
  68. rope_init_fn: Callable = self.compute_default_rope_parameters
  69. if self.rope_type != "default":
  70. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  71. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  72. self.register_buffer("inv_freq", inv_freq, persistent=False)
  73. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  74. @staticmethod
  75. def compute_default_rope_parameters(
  76. config: Lfm2Config | None = None,
  77. device: Optional["torch.device"] = None,
  78. seq_len: int | None = None,
  79. ) -> tuple["torch.Tensor", float]:
  80. """
  81. Computes the inverse frequencies according to the original RoPE implementation
  82. Args:
  83. config ([`~transformers.PreTrainedConfig`]):
  84. The model configuration.
  85. device (`torch.device`):
  86. The device to use for initialization of the inverse frequencies.
  87. seq_len (`int`, *optional*):
  88. The current sequence length. Unused for this type of RoPE.
  89. Returns:
  90. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  91. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  92. """
  93. base = config.rope_parameters["rope_theta"]
  94. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  95. attention_factor = 1.0 # Unused in this type of RoPE
  96. # Compute the inverse frequencies
  97. inv_freq = 1.0 / (
  98. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  99. )
  100. return inv_freq, attention_factor
  101. @torch.no_grad()
  102. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  103. def forward(self, x, position_ids):
  104. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  105. position_ids_expanded = position_ids[:, None, :].float()
  106. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  107. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  108. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  109. emb = torch.cat((freqs, freqs), dim=-1)
  110. cos = emb.cos() * self.attention_scaling
  111. sin = emb.sin() * self.attention_scaling
  112. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  113. class Lfm2MLP(nn.Module):
  114. def __init__(self, config: Lfm2Config):
  115. super().__init__()
  116. intermediate_size = config.intermediate_size
  117. if config.block_auto_adjust_ff_dim:
  118. intermediate_size = int(2 * intermediate_size / 3)
  119. # custom dim factor multiplier
  120. if config.block_ffn_dim_multiplier is not None:
  121. intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size)
  122. intermediate_size = config.block_multiple_of * (
  123. (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of
  124. )
  125. self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  126. self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  127. self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False)
  128. def forward(self, x):
  129. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  130. def rotate_half(x):
  131. """Rotates half the hidden dims of the input."""
  132. x1 = x[..., : x.shape[-1] // 2]
  133. x2 = x[..., x.shape[-1] // 2 :]
  134. return torch.cat((-x2, x1), dim=-1)
  135. @use_kernel_func_from_hub("rotary_pos_emb")
  136. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  137. """Applies Rotary Position Embedding to the query and key tensors.
  138. Args:
  139. q (`torch.Tensor`): The query tensor.
  140. k (`torch.Tensor`): The key tensor.
  141. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  142. sin (`torch.Tensor`): The sine part of the rotary embedding.
  143. unsqueeze_dim (`int`, *optional*, defaults to 1):
  144. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  145. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  146. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  147. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  148. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  149. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  150. Returns:
  151. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  152. """
  153. cos = cos.unsqueeze(unsqueeze_dim)
  154. sin = sin.unsqueeze(unsqueeze_dim)
  155. q_embed = (q * cos) + (rotate_half(q) * sin)
  156. k_embed = (k * cos) + (rotate_half(k) * sin)
  157. return q_embed, k_embed
  158. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  159. """
  160. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  161. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  162. """
  163. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  164. if n_rep == 1:
  165. return hidden_states
  166. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  167. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  168. def eager_attention_forward(
  169. module: nn.Module,
  170. query: torch.Tensor,
  171. key: torch.Tensor,
  172. value: torch.Tensor,
  173. attention_mask: torch.Tensor | None,
  174. scaling: float,
  175. dropout: float = 0.0,
  176. **kwargs: Unpack[TransformersKwargs],
  177. ):
  178. key_states = repeat_kv(key, module.num_key_value_groups)
  179. value_states = repeat_kv(value, module.num_key_value_groups)
  180. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  181. if attention_mask is not None:
  182. attn_weights = attn_weights + attention_mask
  183. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  184. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  185. attn_output = torch.matmul(attn_weights, value_states)
  186. attn_output = attn_output.transpose(1, 2).contiguous()
  187. return attn_output, attn_weights
  188. @use_kernelized_func(apply_rotary_pos_emb)
  189. class Lfm2Attention(nn.Module):
  190. """Multi-headed attention from 'Attention Is All You Need' paper"""
  191. def __init__(self, config: Lfm2Config, layer_idx: int):
  192. super().__init__()
  193. self.config = config
  194. self.layer_idx = layer_idx
  195. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  196. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  197. self.scaling = self.head_dim**-0.5
  198. self.is_causal = True
  199. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  200. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  201. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  202. self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  203. self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  204. self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  205. def forward(
  206. self,
  207. hidden_states: torch.Tensor,
  208. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  209. attention_mask: torch.Tensor | None,
  210. past_key_values: Cache | None = None,
  211. **kwargs,
  212. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  213. input_shape = hidden_states.shape[:-1]
  214. hidden_shape = (*input_shape, -1, self.head_dim)
  215. query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  216. key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  217. value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  218. cos, sin = position_embeddings
  219. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  220. if past_key_values is not None:
  221. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  222. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  223. self.config._attn_implementation, eager_attention_forward
  224. )
  225. attn_output, attn_weights = attention_interface(
  226. self,
  227. query_states,
  228. key_states,
  229. value_states,
  230. attention_mask,
  231. dropout=0.0,
  232. scaling=self.scaling,
  233. **kwargs,
  234. )
  235. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  236. output = self.out_proj(attn_output)
  237. return output, attn_weights
  238. def apply_mask_to_padding_states(hidden_states, attention_mask):
  239. """
  240. Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
  241. """
  242. # NOTE: attention mask is a 2D boolean tensor
  243. if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
  244. dtype = hidden_states.dtype
  245. hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
  246. return hidden_states
  247. kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
  248. is_fast_path_available = all(kernel_modules)
class Lfm2ShortConv(nn.Module):
    """Gated depthwise short causal convolution used by the non-attention LFM2 layers.

    `in_proj` produces three hidden_size-wide streams (B, C, x). The convolution runs over `B * x` along the
    sequence axis, its output is gated by `C`, and the result is projected back with `out_proj`.
    """

    def __init__(
        self,
        config: Lfm2Config,
        layer_idx: int,
    ):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Kernel length of the convolution; also the length of the per-layer conv state kept in the cache.
        self.L_cache = config.conv_L_cache
        self.bias = config.conv_bias
        # groups == channels -> depthwise conv. `padding=L_cache - 1` pads both ends; the callers keep only the
        # first `seqlen` outputs, so each position sees only itself and earlier steps (causal).
        self.conv = nn.Conv1d(
            in_channels=config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=self.L_cache,
            groups=config.hidden_size,
            bias=self.bias,
            padding=self.L_cache - 1,
        )
        self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias)

    def cuda_kernels_forward(
        self,
        x: torch.Tensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Fast path using the `causal_conv1d` kernels (selected by `forward` only when they are importable)."""
        x = apply_mask_to_padding_states(x, attention_mask)
        # (batch, seq, 3*hidden) -> (batch, 3*hidden, seq), then split channels into the B, C, x streams.
        BCx = self.in_proj(x).transpose(-1, -2)
        B, C, x = BCx.chunk(3, dim=-2)

        Bx = B * x

        # Drop the singleton "in_channels per group" axis of the depthwise weight: (hidden, 1, L) -> (hidden, L).
        conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
        if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx):
            # Incremental step: squeeze(-1) drops the time axis, which assumes exactly one new token here.
            # NOTE(review): confirm callers never send multi-token chunks once the conv state is populated.
            # The layer's conv_states tensor is passed to the kernel directly (presumably updated in place).
            conv_out = causal_conv1d_update(
                Bx.squeeze(-1),
                past_key_values.layers[self.layer_idx].conv_states,
                conv_weights,
                self.conv.bias,
                None,
            )
            conv_out = conv_out.unsqueeze(-1)
        else:
            if past_key_values is not None:
                # Left-pad the time axis to exactly L_cache steps; a negative pad amount (prompt longer than
                # L_cache) crops from the left, so the last L_cache steps are seeded into the cache.
                conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
                conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx)

            conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)

        y = C * conv_out
        y = self.out_proj(y.transpose(-1, -2).contiguous())
        return y

    def slow_forward(
        self,
        x: torch.Tensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        """Pure-PyTorch path with the same cache behaviour as `cuda_kernels_forward`."""
        seqlen = x.shape[1]

        x = apply_mask_to_padding_states(x, attention_mask)
        BCx = self.in_proj(x).transpose(-1, -2)
        B, C, x = BCx.chunk(3, dim=-2)

        Bx = B * x

        if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx):
            # Incremental step: `update_conv_state` presumably shifts the cached window and appends Bx, returning
            # a (batch, hidden, L_cache) window. TODO confirm against the cache implementation.
            conv_state = past_key_values.update_conv_state(Bx, self.layer_idx)
            # One output step of the depthwise conv: per-channel dot product of the window with the kernel.
            conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
            if self.bias:
                conv_out += self.conv.bias

            conv_out = conv_out.unsqueeze(-1)
        else:
            if past_key_values is not None:
                # Seed the cache with the last L_cache steps (left zero-pad, or left crop when longer).
                conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
                conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx)

            # Keep only the first `seqlen` outputs of the symmetrically padded conv -> causal convolution.
            conv_out = self.conv(Bx)[..., :seqlen]

        y = C * conv_out
        y = y.transpose(-1, -2).contiguous()
        y = self.out_proj(y)
        return y

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_values: Cache | None = None,
        attention_mask: torch.Tensor | None = None,
    ):
        # Fused kernels are used only when importable, on a CUDA device, and outside torch dynamo compilation.
        if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling():
            return self.cuda_kernels_forward(hidden_states, past_key_values, attention_mask)
        return self.slow_forward(hidden_states, past_key_values, attention_mask)
  333. class Lfm2DecoderLayer(GradientCheckpointingLayer):
  334. def __init__(self, config: Lfm2Config, layer_idx: int):
  335. super().__init__()
  336. self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"
  337. if self.is_attention_layer:
  338. self.self_attn = Lfm2Attention(config, layer_idx)
  339. else:
  340. self.conv = Lfm2ShortConv(config, layer_idx)
  341. self.feed_forward = Lfm2MLP(config)
  342. self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  343. self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  344. def forward(
  345. self,
  346. hidden_states: torch.Tensor,
  347. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  348. attention_mask: torch.Tensor | None = None,
  349. position_ids: torch.LongTensor | None = None,
  350. past_key_values: Cache | None = None,
  351. **kwargs,
  352. ) -> torch.Tensor:
  353. residual = hidden_states
  354. if self.is_attention_layer:
  355. hidden_states, _ = self.self_attn(
  356. hidden_states=self.operator_norm(hidden_states),
  357. position_embeddings=position_embeddings,
  358. attention_mask=attention_mask,
  359. position_ids=position_ids,
  360. past_key_values=past_key_values,
  361. **kwargs,
  362. )
  363. else:
  364. hidden_states = self.conv(
  365. hidden_states=self.operator_norm(hidden_states),
  366. past_key_values=past_key_values,
  367. attention_mask=attention_mask,
  368. )
  369. hidden_states = hidden_states + residual
  370. hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
  371. return hidden_states
  372. @auto_docstring
  373. class Lfm2PreTrainedModel(PreTrainedModel):
  374. config: Lfm2Config
  375. base_model_prefix = "model"
  376. supports_gradient_checkpointing = True
  377. _no_split_modules = ["Lfm2DecoderLayer"]
  378. _skip_keys_device_placement = ["past_key_values"]
  379. _supports_flash_attn = True
  380. _supports_sdpa = True
  381. _supports_flex_attn = True
  382. _can_compile_fullgraph = False
  383. _supports_attention_backend = True
  384. _can_record_outputs = {
  385. "hidden_states": Lfm2DecoderLayer,
  386. "attentions": Lfm2Attention,
  387. }
  388. @auto_docstring
  389. class Lfm2Model(Lfm2PreTrainedModel):
  390. def __init__(self, config: Lfm2Config):
  391. super().__init__(config)
  392. self.padding_idx = config.pad_token_id
  393. self.vocab_size = config.vocab_size
  394. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  395. self.layers = nn.ModuleList(
  396. [Lfm2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  397. )
  398. self.rotary_emb = Lfm2RotaryEmbedding(config=config)
  399. self.gradient_checkpointing = False
  400. self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  401. # Initialize weights and apply final processing
  402. self.post_init()
  403. @merge_with_config_defaults
  404. @capture_outputs
  405. @auto_docstring
  406. def forward(
  407. self,
  408. input_ids: torch.LongTensor | None = None,
  409. attention_mask: torch.Tensor | None = None,
  410. position_ids: torch.LongTensor | None = None,
  411. past_key_values: Cache | None = None,
  412. inputs_embeds: torch.FloatTensor | None = None,
  413. use_cache: bool | None = None,
  414. **kwargs: Unpack[TransformersKwargs],
  415. ) -> BaseModelOutputWithPast:
  416. if (input_ids is None) ^ (inputs_embeds is not None):
  417. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  418. if inputs_embeds is None:
  419. inputs_embeds = self.embed_tokens(input_ids)
  420. if use_cache and past_key_values is None:
  421. past_key_values = DynamicCache(config=self.config)
  422. if position_ids is None:
  423. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  424. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  425. position_ids = position_ids.unsqueeze(0)
  426. causal_mask = create_causal_mask(
  427. config=self.config,
  428. inputs_embeds=inputs_embeds,
  429. attention_mask=attention_mask,
  430. past_key_values=past_key_values,
  431. position_ids=position_ids,
  432. )
  433. # Skip masking for decoding stage. We check shape here to be compile-friendly
  434. linear_attention = attention_mask if inputs_embeds.shape[1] != 1 else None
  435. hidden_states = inputs_embeds
  436. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  437. # decoder layers
  438. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  439. layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention
  440. hidden_states = decoder_layer(
  441. hidden_states,
  442. attention_mask=layer_mask,
  443. position_embeddings=position_embeddings,
  444. position_ids=position_ids,
  445. past_key_values=past_key_values,
  446. **kwargs,
  447. )
  448. hidden_states = self.embedding_norm(hidden_states)
  449. return BaseModelOutputWithPast(
  450. last_hidden_state=hidden_states,
  451. past_key_values=past_key_values,
  452. )
  453. @auto_docstring
  454. class Lfm2ForCausalLM(Lfm2PreTrainedModel, GenerationMixin):
  455. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  456. _tp_plan = {"lm_head": "colwise_gather_output"}
  457. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  458. def __init__(self, config):
  459. super().__init__(config)
  460. self.model = Lfm2Model(config)
  461. self.vocab_size = config.vocab_size
  462. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  463. # Initialize weights and apply final processing
  464. self.post_init()
  465. @can_return_tuple
  466. @auto_docstring
  467. def forward(
  468. self,
  469. input_ids: torch.LongTensor | None = None,
  470. attention_mask: torch.Tensor | None = None,
  471. position_ids: torch.LongTensor | None = None,
  472. past_key_values: Cache | None = None,
  473. inputs_embeds: torch.FloatTensor | None = None,
  474. labels: torch.LongTensor | None = None,
  475. use_cache: bool | None = None,
  476. logits_to_keep: int | torch.Tensor = 0,
  477. **kwargs: Unpack[TransformersKwargs],
  478. ) -> CausalLMOutputWithPast:
  479. r"""
  480. Example:
  481. ```python
  482. >>> from transformers import AutoTokenizer, Lfm2ForCausalLM
  483. >>> model = Lfm2ForCausalLM.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")
  484. >>> tokenizer = AutoTokenizer.from_pretrained("meta-lfm2/Lfm2-2-7b-hf")
  485. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  486. >>> inputs = tokenizer(prompt, return_tensors="pt")
  487. >>> # Generate
  488. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  489. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  490. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  491. ```"""
  492. outputs: BaseModelOutputWithPast = self.model(
  493. input_ids=input_ids,
  494. attention_mask=attention_mask,
  495. position_ids=position_ids,
  496. past_key_values=past_key_values,
  497. inputs_embeds=inputs_embeds,
  498. use_cache=use_cache,
  499. **kwargs,
  500. )
  501. hidden_states = outputs.last_hidden_state
  502. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  503. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  504. logits = self.lm_head(hidden_states[:, slice_indices, :])
  505. loss = None
  506. if labels is not None:
  507. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  508. return CausalLMOutputWithPast(
  509. loss=loss,
  510. logits=logits,
  511. past_key_values=outputs.past_key_values,
  512. hidden_states=outputs.hidden_states,
  513. attentions=outputs.attentions,
  514. )
# Public classes exported by this module.
__all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"]