modular_cohere.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. # Copyright 2024 Cohere team. All rights reserved.
  2. #
  3. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  4. # and OPT implementations in this library. It has been modified from its
  5. # original forms to accommodate minor architectural differences compared
  6. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. # This file is based on the LLama model definition file in transformers
  20. """PyTorch Cohere model."""
  21. from collections.abc import Callable
  22. import torch
  23. from torch import nn
  24. from ...cache_utils import Cache
  25. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  28. from ...modeling_rope_utils import dynamic_rope_update
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  32. from ...utils.generic import maybe_autocast
  33. from ..llama.modeling_llama import (
  34. LlamaAttention,
  35. LlamaForCausalLM,
  36. LlamaMLP,
  37. LlamaModel,
  38. LlamaRotaryEmbedding,
  39. eager_attention_forward,
  40. )
  41. from .configuration_cohere import CohereConfig
# Module-level logger, obtained through the library's `logging` utilities and keyed by this module's name.
logger = logging.get_logger(__name__)
  43. class CohereLayerNorm(nn.Module):
  44. def __init__(self, hidden_size=None, eps=1e-5, bias=False):
  45. """The hidden size can be a tuple or an int. The tuple is used for QKNorm to normalize across head_dim"""
  46. super().__init__()
  47. self.weight = nn.Parameter(torch.ones(hidden_size))
  48. self.variance_epsilon = eps
  49. def forward(self, hidden_states):
  50. input_dtype = hidden_states.dtype
  51. hidden_states = hidden_states.to(torch.float32)
  52. mean = hidden_states.mean(-1, keepdim=True)
  53. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  54. hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon)
  55. hidden_states = self.weight.to(torch.float32) * hidden_states
  56. return hidden_states.to(input_dtype)
  57. class CohereRotaryEmbedding(LlamaRotaryEmbedding):
  58. @torch.no_grad()
  59. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  60. def forward(self, x, position_ids):
  61. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  62. position_ids_expanded = position_ids[:, None, :].float()
  63. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  64. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  65. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  66. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  67. cos = emb.cos() * self.attention_scaling
  68. sin = emb.sin() * self.attention_scaling
  69. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  70. def rotate_half(x):
  71. # Split and rotate. Note that this function is different from e.g. Llama.
  72. x1 = x[..., ::2]
  73. x2 = x[..., 1::2]
  74. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  75. return rot_x
  76. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  77. """Applies Rotary Position Embedding to the query and key tensors.
  78. Args:
  79. q (`torch.Tensor`): The query tensor.
  80. k (`torch.Tensor`): The key tensor.
  81. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  82. sin (`torch.Tensor`): The sine part of the rotary embedding.
  83. unsqueeze_dim (`int`, *optional*, defaults to 1):
  84. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  85. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  86. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  87. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  88. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  89. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  90. Returns:
  91. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  92. """
  93. dtype = q.dtype
  94. q = q.float()
  95. k = k.float()
  96. cos = cos.unsqueeze(unsqueeze_dim)
  97. sin = sin.unsqueeze(unsqueeze_dim)
  98. q_embed = (q * cos) + (rotate_half(q) * sin)
  99. k_embed = (k * cos) + (rotate_half(k) * sin)
  100. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  101. class CohereMLP(LlamaMLP):
  102. def __init__(self, config):
  103. super().__init__(config)
  104. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  105. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  106. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  107. class CohereAttention(LlamaAttention):
  108. """Multi-headed attention from 'Attention Is All You Need' paper"""
  109. def __init__(self, config: CohereConfig, layer_idx: int | None = None):
  110. super().__init__(config, layer_idx)
  111. self.use_qk_norm = config.use_qk_norm
  112. if self.use_qk_norm:
  113. # When sharding the model using Tensor Parallelism, need to be careful to use n_local_heads
  114. self.q_norm = CohereLayerNorm(
  115. hidden_size=(config.num_attention_heads, self.head_dim), eps=config.layer_norm_eps
  116. )
  117. self.k_norm = CohereLayerNorm(
  118. hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps
  119. )
  120. def forward(
  121. self,
  122. hidden_states: torch.Tensor,
  123. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  124. attention_mask: torch.Tensor | None,
  125. past_key_values: Cache | None = None,
  126. **kwargs: Unpack[FlashAttentionKwargs],
  127. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  128. input_shape = hidden_states.shape[:-1]
  129. hidden_shape = (*input_shape, -1, self.head_dim)
  130. query_states = self.q_proj(hidden_states).view(hidden_shape)
  131. key_states = self.k_proj(hidden_states).view(hidden_shape)
  132. value_states = self.v_proj(hidden_states).view(hidden_shape)
  133. if self.use_qk_norm: # main diff from Llama
  134. query_states = self.q_norm(query_states)
  135. key_states = self.k_norm(key_states)
  136. query_states = query_states.transpose(1, 2)
  137. key_states = key_states.transpose(1, 2)
  138. value_states = value_states.transpose(1, 2)
  139. cos, sin = position_embeddings
  140. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  141. if past_key_values is not None:
  142. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  143. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  144. self.config._attn_implementation, eager_attention_forward
  145. )
  146. attn_output, attn_weights = attention_interface(
  147. self,
  148. query_states,
  149. key_states,
  150. value_states,
  151. attention_mask,
  152. dropout=0.0 if not self.training else self.attention_dropout,
  153. scaling=self.scaling,
  154. **kwargs,
  155. )
  156. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  157. attn_output = self.o_proj(attn_output)
  158. return attn_output, attn_weights
  159. class CohereDecoderLayer(GradientCheckpointingLayer):
  160. def __init__(self, config: CohereConfig, layer_idx: int):
  161. super().__init__()
  162. self.hidden_size = config.hidden_size
  163. self.self_attn = CohereAttention(config=config, layer_idx=layer_idx)
  164. self.mlp = CohereMLP(config)
  165. self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  166. def forward(
  167. self,
  168. hidden_states: torch.Tensor,
  169. attention_mask: torch.Tensor | None = None,
  170. position_ids: torch.LongTensor | None = None,
  171. past_key_values: Cache | None = None,
  172. use_cache: bool | None = False,
  173. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  174. **kwargs: Unpack[FlashAttentionKwargs],
  175. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  176. """
  177. Args:
  178. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  179. attention_mask (`torch.FloatTensor`, *optional*):
  180. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  181. query_sequence_length, key_sequence_length)` if default attention is used.
  182. past_key_values (`Cache`, *optional*): cached past key and value projection states
  183. output_attentions (`bool`, *optional*):
  184. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  185. returned tensors for more detail.
  186. use_cache (`bool`, *optional*):
  187. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  188. (see `past_key_values`).
  189. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  190. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  191. with `head_dim` being the embedding dimension of each attention head.
  192. """
  193. residual = hidden_states
  194. hidden_states = self.input_layernorm(hidden_states)
  195. hidden_states_attention, _ = self.self_attn(
  196. hidden_states=hidden_states,
  197. attention_mask=attention_mask,
  198. position_ids=position_ids,
  199. past_key_values=past_key_values,
  200. use_cache=use_cache,
  201. position_embeddings=position_embeddings,
  202. **kwargs,
  203. )
  204. hidden_states_mlp = self.mlp(hidden_states)
  205. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  206. return hidden_states
  207. class CohereModel(LlamaModel):
  208. def __init__(self, config: CohereConfig):
  209. super().__init__(config)
  210. self.layers = nn.ModuleList(
  211. [CohereDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  212. )
  213. self.norm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  214. class CohereForCausalLM(LlamaForCausalLM):
  215. def __init__(self, config):
  216. super().__init__(config)
  217. self.model = CohereModel(config)
  218. self.logit_scale = config.logit_scale
  219. self.tie_word_embeddings = config.tie_word_embeddings
  220. @can_return_tuple
  221. @auto_docstring
  222. def forward(
  223. self,
  224. input_ids: torch.LongTensor | None = None,
  225. attention_mask: torch.Tensor | None = None,
  226. position_ids: torch.LongTensor | None = None,
  227. past_key_values: Cache | None = None,
  228. inputs_embeds: torch.FloatTensor | None = None,
  229. labels: torch.LongTensor | None = None,
  230. use_cache: bool | None = None,
  231. logits_to_keep: int | torch.Tensor = 0,
  232. **kwargs: Unpack[TransformersKwargs],
  233. ) -> CausalLMOutputWithPast:
  234. r"""
  235. Example:
  236. ```python
  237. >> from transformers import AutoTokenizer, CohereForCausalLM
  238. >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01")
  239. >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
  240. >> prompt = "Hey, are you conscious? Can you talk to me?"
  241. >> inputs = tokenizer(prompt, return_tensors="pt")
  242. >> # Generate
  243. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  244. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  245. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  246. ```"""
  247. outputs: BaseModelOutputWithPast = self.model(
  248. input_ids=input_ids,
  249. attention_mask=attention_mask,
  250. position_ids=position_ids,
  251. past_key_values=past_key_values,
  252. inputs_embeds=inputs_embeds,
  253. use_cache=use_cache,
  254. **kwargs,
  255. )
  256. hidden_states = outputs.last_hidden_state
  257. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  258. logits = self.lm_head(hidden_states[:, slice_indices, :])
  259. logits = logits * self.logit_scale # main diff from Llama
  260. loss = None
  261. if labels is not None:
  262. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  263. return CausalLMOutputWithPast(
  264. loss=loss,
  265. logits=logits,
  266. past_key_values=outputs.past_key_values,
  267. hidden_states=outputs.hidden_states,
  268. attentions=outputs.attentions,
  269. )
# Public names exported by this module.
# `CoherePreTrainedModel` is not defined in this file, hence the `noqa: F822` (undefined name in `__all__`).
# NOTE(review): presumably it is supplied by the generated modeling module — confirm before relying on
# `from <this module> import *`, which would fail on the missing name.
__all__ = [
    "CohereForCausalLM",
    "CohereModel",
    "CoherePreTrainedModel",  # noqa: F822
]