modeling_cohere2.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/cohere2/modular_cohere2.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_cohere2.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
  8. #
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from collections.abc import Callable
  22. from typing import Optional
  23. import torch
  24. import torch.nn as nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache
  27. from ...generation import GenerationMixin
  28. from ...integrations import use_kernelized_func
  29. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  37. from ...utils.output_capturing import capture_outputs
  38. from .configuration_cohere2 import Cohere2Config
  39. class Cohere2RotaryEmbedding(nn.Module):
  40. inv_freq: torch.Tensor # fix linting for `register_buffer`
  41. def __init__(self, config: Cohere2Config, device=None):
  42. super().__init__()
  43. self.max_seq_len_cached = config.max_position_embeddings
  44. self.original_max_seq_len = config.max_position_embeddings
  45. self.config = config
  46. self.rope_type = self.config.rope_parameters["rope_type"]
  47. rope_init_fn: Callable = self.compute_default_rope_parameters
  48. if self.rope_type != "default":
  49. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  50. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  51. self.register_buffer("inv_freq", inv_freq, persistent=False)
  52. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  53. @staticmethod
  54. def compute_default_rope_parameters(
  55. config: Cohere2Config | None = None,
  56. device: Optional["torch.device"] = None,
  57. seq_len: int | None = None,
  58. ) -> tuple["torch.Tensor", float]:
  59. """
  60. Computes the inverse frequencies according to the original RoPE implementation
  61. Args:
  62. config ([`~transformers.PreTrainedConfig`]):
  63. The model configuration.
  64. device (`torch.device`):
  65. The device to use for initialization of the inverse frequencies.
  66. seq_len (`int`, *optional*):
  67. The current sequence length. Unused for this type of RoPE.
  68. Returns:
  69. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  70. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  71. """
  72. base = config.rope_parameters["rope_theta"]
  73. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  74. attention_factor = 1.0 # Unused in this type of RoPE
  75. # Compute the inverse frequencies
  76. inv_freq = 1.0 / (
  77. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  78. )
  79. return inv_freq, attention_factor
  80. @torch.no_grad()
  81. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  82. def forward(self, x, position_ids):
  83. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  84. position_ids_expanded = position_ids[:, None, :].float()
  85. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  86. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  87. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  88. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  89. cos = emb.cos() * self.attention_scaling
  90. sin = emb.sin() * self.attention_scaling
  91. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  92. class Cohere2LayerNorm(nn.Module):
  93. def __init__(self, hidden_size=None, eps=1e-5, bias=False):
  94. """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
  95. super().__init__()
  96. self.weight = nn.Parameter(torch.ones(hidden_size))
  97. self.variance_epsilon = eps
  98. def forward(self, hidden_states):
  99. input_dtype = hidden_states.dtype
  100. hidden_states = hidden_states.to(torch.float32)
  101. mean = hidden_states.mean(-1, keepdim=True)
  102. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  103. hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
  104. hidden_states = self.weight.to(torch.float32) * hidden_states
  105. return hidden_states.to(input_dtype)
  106. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  107. """
  108. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  109. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  110. """
  111. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  112. if n_rep == 1:
  113. return hidden_states
  114. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  115. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  116. def eager_attention_forward(
  117. module: nn.Module,
  118. query: torch.Tensor,
  119. key: torch.Tensor,
  120. value: torch.Tensor,
  121. attention_mask: torch.Tensor | None,
  122. scaling: float,
  123. dropout: float = 0.0,
  124. **kwargs: Unpack[TransformersKwargs],
  125. ):
  126. key_states = repeat_kv(key, module.num_key_value_groups)
  127. value_states = repeat_kv(value, module.num_key_value_groups)
  128. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  129. if attention_mask is not None:
  130. attn_weights = attn_weights + attention_mask
  131. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  132. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  133. attn_output = torch.matmul(attn_weights, value_states)
  134. attn_output = attn_output.transpose(1, 2).contiguous()
  135. return attn_output, attn_weights
  136. def rotate_half(x):
  137. # Split and rotate. Note that this function is different from e.g. Llama.
  138. x1 = x[..., ::2]
  139. x2 = x[..., 1::2]
  140. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  141. return rot_x
  142. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  143. """Applies Rotary Position Embedding to the query and key tensors.
  144. Args:
  145. q (`torch.Tensor`): The query tensor.
  146. k (`torch.Tensor`): The key tensor.
  147. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  148. sin (`torch.Tensor`): The sine part of the rotary embedding.
  149. unsqueeze_dim (`int`, *optional*, defaults to 1):
  150. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  151. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  152. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  153. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  154. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  155. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  156. Returns:
  157. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  158. """
  159. dtype = q.dtype
  160. q = q.float()
  161. k = k.float()
  162. cos = cos.unsqueeze(unsqueeze_dim)
  163. sin = sin.unsqueeze(unsqueeze_dim)
  164. q_embed = (q * cos) + (rotate_half(q) * sin)
  165. k_embed = (k * cos) + (rotate_half(k) * sin)
  166. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  167. @use_kernelized_func(apply_rotary_pos_emb)
  168. class Cohere2Attention(nn.Module):
  169. """Multi-headed attention from 'Attention Is All You Need' paper"""
  170. def __init__(self, config: Cohere2Config, layer_idx: int | None = None):
  171. super().__init__()
  172. self.config = config
  173. self.layer_idx = layer_idx
  174. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  175. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  176. self.scaling = self.head_dim**-0.5
  177. self.attention_dropout = config.attention_dropout
  178. self.is_causal = True
  179. layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  180. self.sliding_window = config.sliding_window if layer_type == "sliding_attention" else None
  181. self.q_proj = nn.Linear(
  182. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  183. )
  184. self.k_proj = nn.Linear(
  185. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  186. )
  187. self.v_proj = nn.Linear(
  188. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  189. )
  190. self.o_proj = nn.Linear(
  191. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  192. )
  193. def forward(
  194. self,
  195. hidden_states: torch.Tensor,
  196. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  197. attention_mask: torch.Tensor | None,
  198. past_key_values: Cache | None = None,
  199. **kwargs: Unpack[TransformersKwargs],
  200. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  201. input_shape = hidden_states.shape[:-1]
  202. hidden_shape = (*input_shape, -1, self.head_dim)
  203. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  204. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  205. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  206. cos, sin = position_embeddings
  207. if self.sliding_window is not None:
  208. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  209. if past_key_values is not None:
  210. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  211. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  212. self.config._attn_implementation, eager_attention_forward
  213. )
  214. attn_output, attn_weights = attention_interface(
  215. self,
  216. query_states,
  217. key_states,
  218. value_states,
  219. attention_mask,
  220. dropout=0.0 if not self.training else self.attention_dropout,
  221. scaling=self.scaling,
  222. sliding_window=self.sliding_window,
  223. **kwargs,
  224. )
  225. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  226. attn_output = self.o_proj(attn_output)
  227. return attn_output, attn_weights
  228. class Cohere2MLP(nn.Module):
  229. def __init__(self, config):
  230. super().__init__()
  231. self.config = config
  232. self.hidden_size = config.hidden_size
  233. self.intermediate_size = config.intermediate_size
  234. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  235. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  236. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  237. self.act_fn = ACT2FN[config.hidden_act]
  238. def forward(self, x):
  239. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  240. return down_proj
  241. class Cohere2DecoderLayer(GradientCheckpointingLayer):
  242. def __init__(self, config: Cohere2Config, layer_idx: int):
  243. super().__init__()
  244. self.hidden_size = config.hidden_size
  245. self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
  246. self.mlp = Cohere2MLP(config)
  247. self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  248. def forward(
  249. self,
  250. hidden_states: torch.Tensor,
  251. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  252. attention_mask: torch.Tensor | None = None,
  253. past_key_values: Cache | None = None,
  254. use_cache: bool | None = False,
  255. **kwargs: Unpack[TransformersKwargs],
  256. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  257. """
  258. Args:
  259. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  260. attention_mask (`torch.FloatTensor`, *optional*):
  261. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  262. query_sequence_length, key_sequence_length)` if default attention is used.
  263. past_key_values (`Cache`, *optional*): cached past key and value projection states
  264. output_attentions (`bool`, *optional*):
  265. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  266. returned tensors for more detail.
  267. use_cache (`bool`, *optional*):
  268. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  269. (see `past_key_values`).
  270. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  271. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  272. with `head_dim` being the embedding dimension of each attention head.
  273. """
  274. residual = hidden_states
  275. hidden_states = self.input_layernorm(hidden_states)
  276. hidden_states_attention, _ = self.self_attn(
  277. hidden_states=hidden_states,
  278. position_embeddings=position_embeddings,
  279. attention_mask=attention_mask,
  280. past_key_values=past_key_values,
  281. use_cache=use_cache,
  282. **kwargs,
  283. )
  284. hidden_states_mlp = self.mlp(hidden_states)
  285. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  286. return hidden_states
  287. @auto_docstring
  288. class Cohere2PreTrainedModel(PreTrainedModel):
  289. config: Cohere2Config
  290. base_model_prefix = "model"
  291. supports_gradient_checkpointing = True
  292. _no_split_modules = ["Cohere2DecoderLayer"]
  293. _skip_keys_device_placement = ["past_key_values"]
  294. _supports_flash_attn = True
  295. _supports_sdpa = True
  296. _supports_flex_attn = True
  297. _can_compile_fullgraph = True
  298. _supports_attention_backend = True
  299. _can_record_outputs = {
  300. "hidden_states": Cohere2DecoderLayer,
  301. "attentions": Cohere2Attention,
  302. }
  303. @auto_docstring
  304. class Cohere2Model(Cohere2PreTrainedModel):
  305. def __init__(self, config: Cohere2Config):
  306. super().__init__(config)
  307. self.padding_idx = config.pad_token_id
  308. self.vocab_size = config.vocab_size
  309. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  310. self.layers = nn.ModuleList(
  311. [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  312. )
  313. self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  314. self.rotary_emb = Cohere2RotaryEmbedding(config)
  315. self.gradient_checkpointing = False
  316. # Initialize weights and apply final processing
  317. self.post_init()
  318. @merge_with_config_defaults
  319. @capture_outputs
  320. @auto_docstring
  321. def forward(
  322. self,
  323. input_ids: torch.LongTensor | None = None,
  324. attention_mask: torch.Tensor | None = None,
  325. position_ids: torch.LongTensor | None = None,
  326. past_key_values: Cache | None = None,
  327. inputs_embeds: torch.FloatTensor | None = None,
  328. use_cache: bool | None = None,
  329. **kwargs: Unpack[TransformersKwargs],
  330. ) -> BaseModelOutputWithPast:
  331. if (input_ids is None) ^ (inputs_embeds is not None):
  332. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  333. if inputs_embeds is None:
  334. inputs_embeds = self.embed_tokens(input_ids)
  335. if use_cache and past_key_values is None:
  336. past_key_values = DynamicCache(config=self.config)
  337. if position_ids is None:
  338. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  339. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  340. position_ids = position_ids.unsqueeze(0)
  341. if not isinstance(causal_mask_mapping := attention_mask, dict):
  342. mask_kwargs = {
  343. "config": self.config,
  344. "inputs_embeds": inputs_embeds,
  345. "attention_mask": attention_mask,
  346. "past_key_values": past_key_values,
  347. "position_ids": position_ids,
  348. }
  349. causal_mask_mapping = {
  350. "full_attention": create_causal_mask(**mask_kwargs),
  351. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  352. }
  353. hidden_states = inputs_embeds
  354. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  355. for i, decoder_layer in enumerate(self.layers):
  356. hidden_states = decoder_layer(
  357. hidden_states,
  358. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  359. position_embeddings=position_embeddings,
  360. past_key_values=past_key_values,
  361. use_cache=use_cache,
  362. position_ids=position_ids,
  363. **kwargs,
  364. )
  365. hidden_states = self.norm(hidden_states)
  366. return BaseModelOutputWithPast(
  367. last_hidden_state=hidden_states,
  368. past_key_values=past_key_values,
  369. )
  370. @auto_docstring
  371. class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
  372. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  373. _tp_plan = {"lm_head": "colwise_gather_output"}
  374. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  375. def __init__(self, config):
  376. super().__init__(config)
  377. self.model = Cohere2Model(config)
  378. self.vocab_size = config.vocab_size
  379. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  380. self.logit_scale = config.logit_scale
  381. self.tie_word_embeddings = config.tie_word_embeddings
  382. # Initialize weights and apply final processing
  383. self.post_init()
  384. @can_return_tuple
  385. @auto_docstring
  386. def forward(
  387. self,
  388. input_ids: torch.LongTensor | None = None,
  389. attention_mask: torch.Tensor | None = None,
  390. position_ids: torch.LongTensor | None = None,
  391. past_key_values: Cache | None = None,
  392. inputs_embeds: torch.FloatTensor | None = None,
  393. labels: torch.LongTensor | None = None,
  394. use_cache: bool | None = None,
  395. logits_to_keep: int | torch.Tensor = 0,
  396. **kwargs: Unpack[TransformersKwargs],
  397. ) -> CausalLMOutputWithPast:
  398. r"""
  399. Example:
  400. ```python
  401. >> from transformers import AutoTokenizer, Cohere2ForCausalLM
  402. >> model = Cohere2ForCausalLM.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
  403. >> tokenizer = AutoTokenizer.from_pretrained("Cohere2ForAI/c4ai-command-r-v01")
  404. >> prompt = "Hey, are you conscious? Can you talk to me?"
  405. >> inputs = tokenizer(prompt, return_tensors="pt")
  406. >> # Generate
  407. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  408. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  409. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  410. ```"""
  411. outputs: BaseModelOutputWithPast = self.model(
  412. input_ids=input_ids,
  413. attention_mask=attention_mask,
  414. position_ids=position_ids,
  415. past_key_values=past_key_values,
  416. inputs_embeds=inputs_embeds,
  417. use_cache=use_cache,
  418. **kwargs,
  419. )
  420. hidden_states = outputs.last_hidden_state
  421. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  422. logits = self.lm_head(hidden_states[:, slice_indices, :])
  423. logits = logits * self.logit_scale # main diff from Llama
  424. loss = None
  425. if labels is not None:
  426. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  427. return CausalLMOutputWithPast(
  428. loss=loss,
  429. logits=logits,
  430. past_key_values=outputs.past_key_values,
  431. hidden_states=outputs.hidden_states,
  432. attentions=outputs.attentions,
  433. )
  434. __all__ = ["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]