modeling_stablelm.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. """PyTorch StableLM model."""
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...masking_utils import create_causal_mask
  28. from ...modeling_layers import (
  29. GenericForSequenceClassification,
  30. GenericForTokenClassification,
  31. GradientCheckpointingLayer,
  32. )
  33. from ...modeling_outputs import (
  34. BaseModelOutputWithPast,
  35. CausalLMOutputWithPast,
  36. )
  37. from ...modeling_rope_utils import (
  38. ROPE_INIT_FUNCTIONS,
  39. dynamic_rope_update,
  40. )
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  44. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  45. from ...utils.output_capturing import capture_outputs
  46. from .configuration_stablelm import StableLmConfig
# Module-level logger, obtained from the library's `logging` helper and keyed by this module's name.
# It is used below by `StableLmAttention.__init__` to emit a one-time warning.
logger = logging.get_logger(__name__)
  48. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->StableLm
  49. class StableLmRotaryEmbedding(nn.Module):
  50. inv_freq: torch.Tensor # fix linting for `register_buffer`
  51. def __init__(self, config: StableLmConfig, device=None):
  52. super().__init__()
  53. self.max_seq_len_cached = config.max_position_embeddings
  54. self.original_max_seq_len = config.max_position_embeddings
  55. self.config = config
  56. self.rope_type = self.config.rope_parameters["rope_type"]
  57. rope_init_fn: Callable = self.compute_default_rope_parameters
  58. if self.rope_type != "default":
  59. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  60. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  61. self.register_buffer("inv_freq", inv_freq, persistent=False)
  62. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  63. @staticmethod
  64. # Ignore copy
  65. def compute_default_rope_parameters(
  66. config: StableLmConfig | None = None,
  67. device: Optional["torch.device"] = None,
  68. seq_len: int | None = None,
  69. ) -> tuple["torch.Tensor", float]:
  70. """
  71. Computes the inverse frequencies according to the original RoPE implementation
  72. Args:
  73. config ([`~transformers.PreTrainedConfig`]):
  74. The model configuration.
  75. device (`torch.device`):
  76. The device to use for initialization of the inverse frequencies.
  77. seq_len (`int`, *optional*):
  78. The current sequence length. Unused for this type of RoPE.
  79. Returns:
  80. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  81. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  82. """
  83. base = config.rope_parameters["rope_theta"]
  84. partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
  85. head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  86. dim = int(head_dim * partial_rotary_factor)
  87. attention_factor = 1.0 # Unused in this type of RoPE
  88. # Compute the inverse frequencies
  89. inv_freq = 1.0 / (
  90. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  91. )
  92. return inv_freq, attention_factor
  93. @torch.no_grad()
  94. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  95. def forward(self, x, position_ids):
  96. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  97. position_ids_expanded = position_ids[:, None, :].float()
  98. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  99. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  100. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  101. emb = torch.cat((freqs, freqs), dim=-1)
  102. cos = emb.cos() * self.attention_scaling
  103. sin = emb.sin() * self.attention_scaling
  104. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  105. # Copied from transformers.models.llama.modeling_llama.rotate_half
  106. def rotate_half(x):
  107. """Rotates half the hidden dims of the input."""
  108. x1 = x[..., : x.shape[-1] // 2]
  109. x2 = x[..., x.shape[-1] // 2 :]
  110. return torch.cat((-x2, x1), dim=-1)
  111. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  112. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  113. """Applies Rotary Position Embedding to the query and key tensors.
  114. Args:
  115. q (`torch.Tensor`): The query tensor.
  116. k (`torch.Tensor`): The key tensor.
  117. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  118. sin (`torch.Tensor`): The sine part of the rotary embedding.
  119. unsqueeze_dim (`int`, *optional*, defaults to 1):
  120. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  121. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  122. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  123. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  124. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  125. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  126. Returns:
  127. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  128. """
  129. cos = cos.unsqueeze(unsqueeze_dim)
  130. sin = sin.unsqueeze(unsqueeze_dim)
  131. q_embed = (q * cos) + (rotate_half(q) * sin)
  132. k_embed = (k * cos) + (rotate_half(k) * sin)
  133. return q_embed, k_embed
  134. # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->StableLm
  135. class StableLmMLP(nn.Module):
  136. def __init__(self, config):
  137. super().__init__()
  138. self.config = config
  139. self.hidden_size = config.hidden_size
  140. self.intermediate_size = config.intermediate_size
  141. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  142. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  143. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  144. self.act_fn = ACT2FN[config.hidden_act]
  145. def forward(self, x):
  146. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  147. return down_proj
  148. class StableLmLayerNormPerHead(nn.Module):
  149. def __init__(self, dim, num_heads, eps=1e-5, bias=False):
  150. super().__init__()
  151. self.dim = dim
  152. self.num_heads = num_heads
  153. self.norms = nn.ModuleList([nn.LayerNorm(dim, eps=eps, bias=bias) for _ in range(self.num_heads)])
  154. def forward(self, hidden_states: torch.Tensor):
  155. # Split along the num_heads axis to get per-head inputs
  156. # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim] * num_heads
  157. states_per_heads = torch.split(hidden_states, 1, dim=1)
  158. # Normalize and merge the heads back together
  159. return torch.cat([norm(hidden_states) for norm, hidden_states in zip(self.norms, states_per_heads)], dim=1)
  160. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  161. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  162. """
  163. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  164. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  165. """
  166. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  167. if n_rep == 1:
  168. return hidden_states
  169. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  170. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  171. # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
  172. def eager_attention_forward(
  173. module: nn.Module,
  174. query: torch.Tensor,
  175. key: torch.Tensor,
  176. value: torch.Tensor,
  177. attention_mask: torch.Tensor | None,
  178. scaling: float,
  179. dropout: float = 0.0,
  180. **kwargs: Unpack[TransformersKwargs],
  181. ):
  182. key_states = repeat_kv(key, module.num_key_value_groups)
  183. value_states = repeat_kv(value, module.num_key_value_groups)
  184. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  185. if attention_mask is not None:
  186. attn_weights = attn_weights + attention_mask
  187. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  188. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  189. attn_output = torch.matmul(attn_weights, value_states)
  190. attn_output = attn_output.transpose(1, 2).contiguous()
  191. return attn_output, attn_weights
  192. class StableLmAttention(nn.Module):
  193. """Multi-headed attention from 'Attention Is All You Need' paper"""
  194. def __init__(self, config: StableLmConfig, layer_idx: int | None = None):
  195. super().__init__()
  196. self.config = config
  197. self.layer_idx = layer_idx
  198. if layer_idx is None:
  199. logger.warning_once(
  200. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  201. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  202. "when creating this class."
  203. )
  204. self.hidden_size = config.hidden_size
  205. self.num_heads = config.num_attention_heads
  206. self.head_dim = self.hidden_size // self.num_heads
  207. self.num_key_value_heads = config.num_key_value_heads
  208. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  209. self.rotary_ndims = int(self.head_dim * config.rope_parameters["partial_rotary_factor"])
  210. self.is_causal = True
  211. self.scaling = self.head_dim**-0.5
  212. if (self.head_dim * self.num_heads) != self.hidden_size:
  213. raise ValueError(
  214. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  215. f" and `num_heads`: {self.num_heads})."
  216. )
  217. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias)
  218. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
  219. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.use_qkv_bias)
  220. self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  221. self.qk_layernorm = config.qk_layernorm
  222. if self.qk_layernorm:
  223. self.q_layernorm = StableLmLayerNormPerHead(self.head_dim, self.num_heads, eps=config.layer_norm_eps)
  224. self.k_layernorm = StableLmLayerNormPerHead(
  225. self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
  226. )
  227. self.attention_dropout = config.attention_dropout
  228. def forward(
  229. self,
  230. hidden_states: torch.Tensor,
  231. attention_mask: torch.Tensor | None = None,
  232. position_ids: torch.LongTensor | None = None,
  233. past_key_values: Cache | None = None,
  234. output_attentions: bool = False,
  235. use_cache: bool = False,
  236. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  237. **kwargs,
  238. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  239. bsz, q_len, _ = hidden_states.size()
  240. query_states = self.q_proj(hidden_states)
  241. key_states = self.k_proj(hidden_states)
  242. value_states = self.v_proj(hidden_states)
  243. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  244. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  245. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  246. if self.qk_layernorm:
  247. query_states = self.q_layernorm(query_states)
  248. key_states = self.k_layernorm(key_states)
  249. cos, sin = position_embeddings
  250. query_rot, query_pass = (
  251. query_states[..., : self.rotary_ndims],
  252. query_states[..., self.rotary_ndims :],
  253. )
  254. key_rot, key_pass = (
  255. key_states[..., : self.rotary_ndims],
  256. key_states[..., self.rotary_ndims :],
  257. )
  258. # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
  259. query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
  260. # [batch_size, seq_length, num_heads, head_dim]
  261. query_states = torch.cat((query_rot, query_pass), dim=-1)
  262. key_states = torch.cat((key_rot, key_pass), dim=-1)
  263. if past_key_values is not None:
  264. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  265. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  266. self.config._attn_implementation, eager_attention_forward
  267. )
  268. attn_output, attn_weights = attention_interface(
  269. self,
  270. query_states,
  271. key_states,
  272. value_states,
  273. attention_mask,
  274. dropout=0.0 if not self.training else self.attention_dropout,
  275. scaling=self.scaling,
  276. position_ids=position_ids, # pass `position_ids` for FA2
  277. **kwargs,
  278. )
  279. attn_output = attn_output.reshape(bsz, q_len, -1)
  280. attn_output = self.o_proj(attn_output)
  281. return attn_output, attn_weights
  282. class StableLmDecoderLayer(GradientCheckpointingLayer):
  283. def __init__(self, config: StableLmConfig, layer_idx: int):
  284. super().__init__()
  285. self.use_parallel_residual = config.use_parallel_residual
  286. self.hidden_size = config.hidden_size
  287. self.self_attn = StableLmAttention(config, layer_idx=layer_idx)
  288. self.mlp = StableLmMLP(config)
  289. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  290. self.post_attention_layernorm = None
  291. if not self.use_parallel_residual:
  292. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  293. self.dropout = nn.Dropout(config.hidden_dropout)
  294. def forward(
  295. self,
  296. hidden_states: torch.Tensor,
  297. attention_mask: torch.Tensor | None = None,
  298. position_ids: torch.LongTensor | None = None,
  299. past_key_values: Cache | None = None,
  300. use_cache: bool | None = False,
  301. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  302. **kwargs,
  303. ) -> torch.Tensor:
  304. residual = hidden_states
  305. hidden_states = self.input_layernorm(hidden_states)
  306. # Self Attention
  307. self_attn_output, _ = self.self_attn(
  308. hidden_states=hidden_states,
  309. attention_mask=attention_mask,
  310. position_ids=position_ids,
  311. past_key_values=past_key_values,
  312. use_cache=use_cache,
  313. position_embeddings=position_embeddings,
  314. )
  315. # copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
  316. if self.use_parallel_residual:
  317. # x = x + attn(ln1(x)) + mlp(ln1(x))
  318. # Fully Connected
  319. mlp_output = self.mlp(hidden_states)
  320. mlp_output = self.dropout(mlp_output)
  321. hidden_states = residual + self_attn_output + mlp_output
  322. else:
  323. # x = x + attn(ln1(x))
  324. # x = x + mlp(ln2(x))
  325. residual = residual + self_attn_output
  326. # Fully Connected
  327. mlp_output = self.mlp(self.post_attention_layernorm(residual))
  328. mlp_output = self.dropout(mlp_output)
  329. hidden_states = residual + mlp_output
  330. return hidden_states
@auto_docstring
class StableLmPreTrainedModel(PreTrainedModel):
    # Common base of every StableLM model class in this file. It defines no methods; it only declares
    # class-level settings that are interpreted by `PreTrainedModel` and related library machinery, which is
    # not visible here. NOTE(review): the exact effect of each flag should be confirmed against the base class.
    config: StableLmConfig
    # Matches the attribute name (`self.model`) under which `StableLmForCausalLM` stores its `StableLmModel`.
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["StableLmDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    # Maps an output field name to the module class associated with it; presumably read by the
    # `capture_outputs` decorator applied to `StableLmModel.forward` (TODO confirm).
    _can_record_outputs = {
        "hidden_states": StableLmDecoderLayer,
        "attentions": StableLmAttention,
    }
  345. @auto_docstring
  346. class StableLmModel(StableLmPreTrainedModel):
  347. """
  348. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`StableLmDecoderLayer`]
  349. Args:
  350. config: StableLmConfig
  351. """
  352. def __init__(self, config: StableLmConfig):
  353. super().__init__(config)
  354. self.padding_idx = config.pad_token_id
  355. self.vocab_size = config.vocab_size
  356. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  357. self.layers = nn.ModuleList(
  358. [StableLmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  359. )
  360. self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  361. self._attn_implementation = config._attn_implementation
  362. self.gradient_checkpointing = False
  363. self.rotary_emb = StableLmRotaryEmbedding(config=self.config)
  364. # Initialize weights and apply final processing
  365. self.post_init()
  366. @merge_with_config_defaults
  367. @capture_outputs
  368. @auto_docstring
  369. def forward(
  370. self,
  371. input_ids: torch.LongTensor | None = None,
  372. attention_mask: torch.Tensor | None = None,
  373. position_ids: torch.LongTensor | None = None,
  374. past_key_values: Cache | None = None,
  375. inputs_embeds: torch.FloatTensor | None = None,
  376. use_cache: bool | None = None,
  377. **kwargs: Unpack[TransformersKwargs],
  378. ) -> BaseModelOutputWithPast:
  379. if (input_ids is None) ^ (inputs_embeds is not None):
  380. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  381. if use_cache and past_key_values is None:
  382. past_key_values = DynamicCache(config=self.config)
  383. if inputs_embeds is None:
  384. inputs_embeds = self.embed_tokens(input_ids)
  385. if position_ids is None:
  386. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  387. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  388. position_ids = position_ids.unsqueeze(0)
  389. causal_mask = create_causal_mask(
  390. config=self.config,
  391. inputs_embeds=inputs_embeds,
  392. attention_mask=attention_mask,
  393. past_key_values=past_key_values,
  394. position_ids=position_ids,
  395. )
  396. hidden_states = inputs_embeds
  397. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  398. for decoder_layer in self.layers:
  399. hidden_states = decoder_layer(
  400. hidden_states,
  401. attention_mask=causal_mask,
  402. position_ids=position_ids,
  403. past_key_values=past_key_values,
  404. use_cache=use_cache,
  405. position_embeddings=position_embeddings,
  406. **kwargs,
  407. )
  408. hidden_states = self.norm(hidden_states)
  409. return BaseModelOutputWithPast(
  410. last_hidden_state=hidden_states,
  411. past_key_values=past_key_values,
  412. )
  413. # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm
  414. class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin):
  415. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  416. # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm
  417. def __init__(self, config):
  418. super().__init__(config)
  419. self.model = StableLmModel(config)
  420. self.vocab_size = config.vocab_size
  421. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  422. # Initialize weights and apply final processing
  423. self.post_init()
  424. @can_return_tuple
  425. @auto_docstring
  426. # Ignore copy
  427. def forward(
  428. self,
  429. input_ids: torch.LongTensor | None = None,
  430. attention_mask: torch.Tensor | None = None,
  431. position_ids: torch.LongTensor | None = None,
  432. past_key_values: Cache | None = None,
  433. inputs_embeds: torch.FloatTensor | None = None,
  434. labels: torch.LongTensor | None = None,
  435. use_cache: bool | None = None,
  436. logits_to_keep: int | torch.Tensor = 0,
  437. **kwargs: Unpack[TransformersKwargs],
  438. ) -> CausalLMOutputWithPast:
  439. r"""
  440. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  441. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  442. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  443. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  444. Example:
  445. ```python
  446. >>> from transformers import AutoTokenizer, StableLmForCausalLM
  447. >>> model = StableLmForCausalLM.from_pretrained("adept/persimmon-8b-base")
  448. >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
  449. >>> prompt = "human: Hey, what should I eat for dinner?"
  450. >>> inputs = tokenizer(prompt, return_tensors="pt")
  451. >>> # Generate
  452. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  453. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  454. 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n'
  455. ```"""
  456. outputs: BaseModelOutputWithPast = self.model(
  457. input_ids=input_ids,
  458. attention_mask=attention_mask,
  459. position_ids=position_ids,
  460. past_key_values=past_key_values,
  461. inputs_embeds=inputs_embeds,
  462. use_cache=use_cache,
  463. **kwargs,
  464. )
  465. hidden_states = outputs.last_hidden_state
  466. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  467. logits = self.lm_head(hidden_states[:, slice_indices, :])
  468. loss = None
  469. if labels is not None:
  470. loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
  471. return CausalLMOutputWithPast(
  472. loss=loss,
  473. logits=logits,
  474. past_key_values=outputs.past_key_values,
  475. hidden_states=outputs.hidden_states,
  476. attentions=outputs.attentions,
  477. )
# Sequence-classification model class. Nothing is defined here: all behaviour is inherited from
# `GenericForSequenceClassification` and `StableLmPreTrainedModel`.
class StableLmForSequenceClassification(GenericForSequenceClassification, StableLmPreTrainedModel): ...
# Token-classification model class. Nothing is defined here: all behaviour is inherited from
# `GenericForTokenClassification` and `StableLmPreTrainedModel`.
class StableLmForTokenClassification(GenericForTokenClassification, StableLmPreTrainedModel): ...
# Public names exported by this module (the set picked up by `from ... import *`).
__all__ = [
    "StableLmForCausalLM",
    "StableLmModel",
    "StableLmPreTrainedModel",
    "StableLmForSequenceClassification",
    "StableLmForTokenClassification",
]