modeling_cohere.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/cohere/modular_cohere.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_cohere.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2024 Cohere team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. # This file is based on the LLama model definition file in transformers
  26. from collections.abc import Callable
  27. from typing import Optional
  28. import torch
  29. from torch import nn
  30. from ...activations import ACT2FN
  31. from ...cache_utils import Cache, DynamicCache
  32. from ...generation import GenerationMixin
  33. from ...integrations import use_kernelized_func
  34. from ...masking_utils import create_causal_mask
  35. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  36. from ...modeling_layers import GradientCheckpointingLayer
  37. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  38. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  39. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  40. from ...processing_utils import Unpack
  41. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  42. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  43. from ...utils.output_capturing import capture_outputs
  44. from .configuration_cohere import CohereConfig
  45. class CohereLayerNorm(nn.Module):
  46. def __init__(self, hidden_size=None, eps=1e-5, bias=False):
  47. """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
  48. super().__init__()
  49. self.weight = nn.Parameter(torch.ones(hidden_size))
  50. self.variance_epsilon = eps
  51. def forward(self, hidden_states):
  52. input_dtype = hidden_states.dtype
  53. hidden_states = hidden_states.to(torch.float32)
  54. mean = hidden_states.mean(-1, keepdim=True)
  55. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  56. hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
  57. hidden_states = self.weight.to(torch.float32) * hidden_states
  58. return hidden_states.to(input_dtype)
  59. class CohereRotaryEmbedding(nn.Module):
  60. inv_freq: torch.Tensor # fix linting for `register_buffer`
  61. def __init__(self, config: CohereConfig, device=None):
  62. super().__init__()
  63. self.max_seq_len_cached = config.max_position_embeddings
  64. self.original_max_seq_len = config.max_position_embeddings
  65. self.config = config
  66. self.rope_type = self.config.rope_parameters["rope_type"]
  67. rope_init_fn: Callable = self.compute_default_rope_parameters
  68. if self.rope_type != "default":
  69. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  70. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  71. self.register_buffer("inv_freq", inv_freq, persistent=False)
  72. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  73. @staticmethod
  74. def compute_default_rope_parameters(
  75. config: CohereConfig | None = None,
  76. device: Optional["torch.device"] = None,
  77. seq_len: int | None = None,
  78. ) -> tuple["torch.Tensor", float]:
  79. """
  80. Computes the inverse frequencies according to the original RoPE implementation
  81. Args:
  82. config ([`~transformers.PreTrainedConfig`]):
  83. The model configuration.
  84. device (`torch.device`):
  85. The device to use for initialization of the inverse frequencies.
  86. seq_len (`int`, *optional*):
  87. The current sequence length. Unused for this type of RoPE.
  88. Returns:
  89. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  90. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  91. """
  92. base = config.rope_parameters["rope_theta"]
  93. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  94. attention_factor = 1.0 # Unused in this type of RoPE
  95. # Compute the inverse frequencies
  96. inv_freq = 1.0 / (
  97. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  98. )
  99. return inv_freq, attention_factor
  100. @torch.no_grad()
  101. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  102. def forward(self, x, position_ids):
  103. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  104. position_ids_expanded = position_ids[:, None, :].float()
  105. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  106. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  107. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  108. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  109. cos = emb.cos() * self.attention_scaling
  110. sin = emb.sin() * self.attention_scaling
  111. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  112. class CohereMLP(nn.Module):
  113. def __init__(self, config):
  114. super().__init__()
  115. self.config = config
  116. self.hidden_size = config.hidden_size
  117. self.intermediate_size = config.intermediate_size
  118. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  119. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  120. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  121. self.act_fn = ACT2FN[config.hidden_act]
  122. def forward(self, x):
  123. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  124. return down_proj
  125. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  126. """
  127. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  128. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  129. """
  130. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  131. if n_rep == 1:
  132. return hidden_states
  133. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  134. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  135. def eager_attention_forward(
  136. module: nn.Module,
  137. query: torch.Tensor,
  138. key: torch.Tensor,
  139. value: torch.Tensor,
  140. attention_mask: torch.Tensor | None,
  141. scaling: float,
  142. dropout: float = 0.0,
  143. **kwargs: Unpack[TransformersKwargs],
  144. ):
  145. key_states = repeat_kv(key, module.num_key_value_groups)
  146. value_states = repeat_kv(value, module.num_key_value_groups)
  147. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  148. if attention_mask is not None:
  149. attn_weights = attn_weights + attention_mask
  150. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  151. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  152. attn_output = torch.matmul(attn_weights, value_states)
  153. attn_output = attn_output.transpose(1, 2).contiguous()
  154. return attn_output, attn_weights
  155. def rotate_half(x):
  156. # Split and rotate. Note that this function is different from e.g. Llama.
  157. x1 = x[..., ::2]
  158. x2 = x[..., 1::2]
  159. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  160. return rot_x
  161. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  162. """Applies Rotary Position Embedding to the query and key tensors.
  163. Args:
  164. q (`torch.Tensor`): The query tensor.
  165. k (`torch.Tensor`): The key tensor.
  166. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  167. sin (`torch.Tensor`): The sine part of the rotary embedding.
  168. unsqueeze_dim (`int`, *optional*, defaults to 1):
  169. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  170. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  171. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  172. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  173. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  174. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  175. Returns:
  176. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  177. """
  178. dtype = q.dtype
  179. q = q.float()
  180. k = k.float()
  181. cos = cos.unsqueeze(unsqueeze_dim)
  182. sin = sin.unsqueeze(unsqueeze_dim)
  183. q_embed = (q * cos) + (rotate_half(q) * sin)
  184. k_embed = (k * cos) + (rotate_half(k) * sin)
  185. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  186. @use_kernelized_func(apply_rotary_pos_emb)
  187. class CohereAttention(nn.Module):
  188. """Multi-headed attention from 'Attention Is All You Need' paper"""
  189. def __init__(self, config: CohereConfig, layer_idx: int | None = None):
  190. super().__init__()
  191. self.config = config
  192. self.layer_idx = layer_idx
  193. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  194. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  195. self.scaling = self.head_dim**-0.5
  196. self.attention_dropout = config.attention_dropout
  197. self.is_causal = True
  198. self.q_proj = nn.Linear(
  199. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  200. )
  201. self.k_proj = nn.Linear(
  202. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  203. )
  204. self.v_proj = nn.Linear(
  205. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  206. )
  207. self.o_proj = nn.Linear(
  208. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  209. )
  210. self.use_qk_norm = config.use_qk_norm
  211. if self.use_qk_norm:
  212. # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
  213. self.q_norm = CohereLayerNorm(
  214. hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
  215. )
  216. self.k_norm = CohereLayerNorm(
  217. hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
  218. )
  219. def forward(
  220. self,
  221. hidden_states: torch.Tensor,
  222. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  223. attention_mask: torch.Tensor | None,
  224. past_key_values: Cache | None = None,
  225. **kwargs: Unpack[FlashAttentionKwargs],
  226. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  227. input_shape = hidden_states.shape[:-1]
  228. hidden_shape = (*input_shape, -1, self.head_dim)
  229. query_states = self.q_proj(hidden_states).view(hidden_shape)
  230. key_states = self.k_proj(hidden_states).view(hidden_shape)
  231. value_states = self.v_proj(hidden_states).view(hidden_shape)
  232. if self.use_qk_norm: # main diff from Llama
  233. query_states = self.q_norm(query_states)
  234. key_states = self.k_norm(key_states)
  235. query_states = query_states.transpose(1, 2)
  236. key_states = key_states.transpose(1, 2)
  237. value_states = value_states.transpose(1, 2)
  238. cos, sin = position_embeddings
  239. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  240. if past_key_values is not None:
  241. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  242. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  243. self.config._attn_implementation, eager_attention_forward
  244. )
  245. attn_output, attn_weights = attention_interface(
  246. self,
  247. query_states,
  248. key_states,
  249. value_states,
  250. attention_mask,
  251. dropout=0.0 if not self.training else self.attention_dropout,
  252. scaling=self.scaling,
  253. **kwargs,
  254. )
  255. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  256. attn_output = self.o_proj(attn_output)
  257. return attn_output, attn_weights
  258. class CohereDecoderLayer(GradientCheckpointingLayer):
  259. def __init__(self, config: CohereConfig, layer_idx: int):
  260. super().__init__()
  261. self.hidden_size = config.hidden_size
  262. self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
  263. self.mlp = CohereMLP(config)
  264. self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  265. def forward(
  266. self,
  267. hidden_states: torch.Tensor,
  268. attention_mask: torch.Tensor | None = None,
  269. position_ids: torch.LongTensor | None = None,
  270. past_key_values: Cache | None = None,
  271. use_cache: bool | None = False,
  272. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  273. **kwargs: Unpack[FlashAttentionKwargs],
  274. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  275. """
  276. Args:
  277. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  278. attention_mask (`torch.FloatTensor`, *optional*):
  279. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  280. query_sequence_length, key_sequence_length)` if default attention is used.
  281. past_key_values (`Cache`, *optional*): cached past key and value projection states
  282. output_attentions (`bool`, *optional*):
  283. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  284. returned tensors for more detail.
  285. use_cache (`bool`, *optional*):
  286. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  287. (see `past_key_values`).
  288. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  289. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  290. with `head_dim` being the embedding dimension of each attention head.
  291. """
  292. residual = hidden_states
  293. hidden_states = self.input_layernorm(hidden_states)
  294. hidden_states_attention, _ = self.self_attn(
  295. hidden_states=hidden_states,
  296. attention_mask=attention_mask,
  297. position_ids=position_ids,
  298. past_key_values=past_key_values,
  299. use_cache=use_cache,
  300. position_embeddings=position_embeddings,
  301. **kwargs,
  302. )
  303. hidden_states_mlp = self.mlp(hidden_states)
  304. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  305. return hidden_states
  306. @auto_docstring
  307. class CoherePreTrainedModel(PreTrainedModel):
  308. config: CohereConfig
  309. base_model_prefix = "model"
  310. supports_gradient_checkpointing = True
  311. _no_split_modules = ["CohereDecoderLayer"]
  312. _skip_keys_device_placement = ["past_key_values"]
  313. _supports_flash_attn = True
  314. _supports_sdpa = True
  315. _supports_flex_attn = True
  316. _can_compile_fullgraph = True
  317. _supports_attention_backend = True
  318. _can_record_outputs = {
  319. "hidden_states": CohereDecoderLayer,
  320. "attentions": CohereAttention,
  321. }
  322. @auto_docstring
  323. class CohereModel(CoherePreTrainedModel):
  324. def __init__(self, config: CohereConfig):
  325. super().__init__(config)
  326. self.padding_idx = config.pad_token_id
  327. self.vocab_size = config.vocab_size
  328. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  329. self.layers = nn.ModuleList(
  330. [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  331. )
  332. self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  333. self.rotary_emb = CohereRotaryEmbedding(config=config)
  334. self.gradient_checkpointing = False
  335. # Initialize weights and apply final processing
  336. self.post_init()
  337. @merge_with_config_defaults
  338. @capture_outputs
  339. @auto_docstring
  340. def forward(
  341. self,
  342. input_ids: torch.LongTensor | None = None,
  343. attention_mask: torch.Tensor | None = None,
  344. position_ids: torch.LongTensor | None = None,
  345. past_key_values: Cache | None = None,
  346. inputs_embeds: torch.FloatTensor | None = None,
  347. use_cache: bool | None = None,
  348. **kwargs: Unpack[TransformersKwargs],
  349. ) -> BaseModelOutputWithPast:
  350. if (input_ids is None) ^ (inputs_embeds is not None):
  351. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  352. if inputs_embeds is None:
  353. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  354. if use_cache and past_key_values is None:
  355. past_key_values = DynamicCache(config=self.config)
  356. if position_ids is None:
  357. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  358. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  359. position_ids = position_ids.unsqueeze(0)
  360. causal_mask = create_causal_mask(
  361. config=self.config,
  362. inputs_embeds=inputs_embeds,
  363. attention_mask=attention_mask,
  364. past_key_values=past_key_values,
  365. position_ids=position_ids,
  366. )
  367. hidden_states = inputs_embeds
  368. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  369. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  370. hidden_states = decoder_layer(
  371. hidden_states,
  372. attention_mask=causal_mask,
  373. position_embeddings=position_embeddings,
  374. position_ids=position_ids,
  375. past_key_values=past_key_values,
  376. use_cache=use_cache,
  377. **kwargs,
  378. )
  379. hidden_states = self.norm(hidden_states)
  380. return BaseModelOutputWithPast(
  381. last_hidden_state=hidden_states,
  382. past_key_values=past_key_values,
  383. )
  384. @auto_docstring
  385. class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
  386. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  387. _tp_plan = {"lm_head": "colwise_gather_output"}
  388. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  389. def __init__(self, config):
  390. super().__init__(config)
  391. self.model = CohereModel(config)
  392. self.vocab_size = config.vocab_size
  393. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  394. self.logit_scale = config.logit_scale
  395. self.tie_word_embeddings = config.tie_word_embeddings
  396. # Initialize weights and apply final processing
  397. self.post_init()
  398. @can_return_tuple
  399. @auto_docstring
  400. def forward(
  401. self,
  402. input_ids: torch.LongTensor | None = None,
  403. attention_mask: torch.Tensor | None = None,
  404. position_ids: torch.LongTensor | None = None,
  405. past_key_values: Cache | None = None,
  406. inputs_embeds: torch.FloatTensor | None = None,
  407. labels: torch.LongTensor | None = None,
  408. use_cache: bool | None = None,
  409. logits_to_keep: int | torch.Tensor = 0,
  410. **kwargs: Unpack[TransformersKwargs],
  411. ) -> CausalLMOutputWithPast:
  412. r"""
  413. Example:
  414. ```python
  415. >> from transformers import AutoTokenizer, CohereForCausalLM
  416. >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
  417. >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
  418. >> prompt = "Hey, are you conscious? Can you talk to me?"
  419. >> inputs = tokenizer(prompt, return_tensors="pt")
  420. >> # Generate
  421. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  422. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  423. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  424. ```"""
  425. outputs: BaseModelOutputWithPast = self.model(
  426. input_ids=input_ids,
  427. attention_mask=attention_mask,
  428. position_ids=position_ids,
  429. past_key_values=past_key_values,
  430. inputs_embeds=inputs_embeds,
  431. use_cache=use_cache,
  432. **kwargs,
  433. )
  434. hidden_states = outputs.last_hidden_state
  435. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  436. logits = self.lm_head(hidden_states[:, slice_indices, :])
  437. logits = logits * self.logit_scale # main diff from Llama
  438. loss = None
  439. if labels is not None:
  440. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  441. return CausalLMOutputWithPast(
  442. loss=loss,
  443. logits=logits,
  444. past_key_values=outputs.past_key_values,
  445. hidden_states=outputs.hidden_states,
  446. attentions=outputs.attentions,
  447. )
  448. __all__ = ["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]