modular_cohere2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. # Copyright 2024 Cohere Inc. HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from collections.abc import Callable
  16. import torch
  17. import torch.nn as nn
  18. from huggingface_hub.dataclasses import strict
  19. from ...cache_utils import Cache, DynamicCache
  20. from ...configuration_utils import PreTrainedConfig
  21. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  22. from ...modeling_outputs import BaseModelOutputWithPast
  23. from ...modeling_rope_utils import (
  24. RopeParameters,
  25. dynamic_rope_update,
  26. )
  27. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  28. from ...processing_utils import Unpack
  29. from ...utils import TransformersKwargs, auto_docstring, logging
  30. from ...utils.generic import maybe_autocast
  31. from ..cohere.modeling_cohere import (
  32. CohereAttention,
  33. CohereDecoderLayer,
  34. CohereForCausalLM,
  35. CohereLayerNorm,
  36. CoherePreTrainedModel,
  37. CohereRotaryEmbedding,
  38. apply_rotary_pos_emb,
  39. eager_attention_forward,
  40. )
  41. from ..gemma2.modeling_gemma2 import Gemma2Model
# Module-level logger from the library's `logging` utilities (not referenced elsewhere in this file as shown).
logger = logging.get_logger(__name__)
  43. @auto_docstring(checkpoint="CohereForAI/c4ai-command-r-v01")
  44. @strict
  45. class Cohere2Config(PreTrainedConfig):
  46. r"""
  47. logit_scale (`float`, *optional*, defaults to 0.0625):
  48. The scaling factor for the output logits.
  49. ```python
  50. >>> from transformers import Cohere2Model, Cohere2Config
  51. >>> # Initializing a Cohere Nextmodel configuration
  52. >>> configuration = Cohere2Config()
  53. >>> # Initializing a model from the Cohere2 configuration
  54. >>> model = Cohere2Model(configuration) # doctest: +SKIP
  55. >>> # Accessing the model configuration
  56. >>> configuration = model.config # doctest: +SKIP
  57. ```
  58. """
  59. model_type = "cohere2"
  60. keys_to_ignore_at_inference = ["past_key_values"]
  61. base_model_tp_plan = {
  62. "layers.*.self_attn.q_proj": "colwise",
  63. "layers.*.self_attn.k_proj": "colwise",
  64. "layers.*.self_attn.v_proj": "colwise",
  65. "layers.*.self_attn.o_proj": "rowwise",
  66. "layers.*.mlp.gate_proj": "colwise",
  67. "layers.*.mlp.up_proj": "colwise",
  68. "layers.*.mlp.down_proj": "rowwise",
  69. }
  70. base_model_pp_plan = {
  71. "embed_tokens": (["input_ids"], ["inputs_embeds"]),
  72. "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
  73. "norm": (["hidden_states"], ["hidden_states"]),
  74. }
  75. vocab_size: int = 256000
  76. hidden_size: int = 8192
  77. intermediate_size: int = 22528
  78. logit_scale: float = 0.0625
  79. num_hidden_layers: int = 40
  80. num_attention_heads: int = 64
  81. num_key_value_heads: int | None = None
  82. hidden_act: str = "silu"
  83. max_position_embeddings: int = 8192
  84. initializer_range: float = 0.02
  85. layer_norm_eps: float = 1e-5
  86. use_cache: bool = True
  87. pad_token_id: int | None = 0
  88. bos_token_id: int | None = 5
  89. eos_token_id: int | list[int] | None = 255001
  90. tie_word_embeddings: bool = True
  91. rope_parameters: RopeParameters | dict | None = None
  92. attention_bias: bool = False
  93. attention_dropout: float | int = 0.0
  94. sliding_window: int | None = 4096
  95. layer_types: list[str] | None = None
  96. def __post_init__(self, **kwargs):
  97. if self.num_key_value_heads is None:
  98. self.num_key_value_heads = self.num_attention_heads
  99. # Need to specify head_dim in the config so it can be used in the attention forward functions
  100. self.head_dim = self.hidden_size // self.num_attention_heads
  101. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  102. if self.layer_types is None:
  103. # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
  104. _sliding_window_pattern = kwargs.pop("sliding_window_pattern", 4)
  105. self.layer_types = [
  106. "sliding_attention" if bool((i + 1) % _sliding_window_pattern) else "full_attention"
  107. for i in range(self.num_hidden_layers)
  108. ]
  109. super().__post_init__(**kwargs)
  110. class Cohere2RotaryEmbedding(CohereRotaryEmbedding):
  111. @torch.no_grad()
  112. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  113. def forward(self, x, position_ids):
  114. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  115. position_ids_expanded = position_ids[:, None, :].float()
  116. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  117. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  118. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  119. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  120. cos = emb.cos() * self.attention_scaling
  121. sin = emb.sin() * self.attention_scaling
  122. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  123. class Cohere2LayerNorm(CohereLayerNorm):
  124. pass
  125. class Cohere2Attention(CohereAttention):
  126. """Multi-headed attention from 'Attention Is All You Need' paper"""
  127. def __init__(self, config: Cohere2Config, layer_idx: int | None = None):
  128. nn.Module.__init__(self)
  129. self.config = config
  130. self.layer_idx = layer_idx
  131. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  132. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  133. self.scaling = self.head_dim**-0.5
  134. self.attention_dropout = config.attention_dropout
  135. self.is_causal = True
  136. layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  137. self.sliding_window = config.sliding_window if layer_type == "sliding_attention" else None
  138. self.q_proj = nn.Linear(
  139. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  140. )
  141. self.k_proj = nn.Linear(
  142. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  143. )
  144. self.v_proj = nn.Linear(
  145. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  146. )
  147. self.o_proj = nn.Linear(
  148. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  149. )
  150. def forward(
  151. self,
  152. hidden_states: torch.Tensor,
  153. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  154. attention_mask: torch.Tensor | None,
  155. past_key_values: Cache | None = None,
  156. **kwargs: Unpack[TransformersKwargs],
  157. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  158. input_shape = hidden_states.shape[:-1]
  159. hidden_shape = (*input_shape, -1, self.head_dim)
  160. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  161. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  162. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  163. cos, sin = position_embeddings
  164. if self.sliding_window is not None:
  165. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  166. if past_key_values is not None:
  167. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  168. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  169. self.config._attn_implementation, eager_attention_forward
  170. )
  171. attn_output, attn_weights = attention_interface(
  172. self,
  173. query_states,
  174. key_states,
  175. value_states,
  176. attention_mask,
  177. dropout=0.0 if not self.training else self.attention_dropout,
  178. scaling=self.scaling,
  179. sliding_window=self.sliding_window,
  180. **kwargs,
  181. )
  182. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  183. attn_output = self.o_proj(attn_output)
  184. return attn_output, attn_weights
  185. class Cohere2DecoderLayer(CohereDecoderLayer):
  186. def __init__(self, config: Cohere2Config, layer_idx: int):
  187. super().__init__(config, layer_idx)
  188. def forward(
  189. self,
  190. hidden_states: torch.Tensor,
  191. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  192. attention_mask: torch.Tensor | None = None,
  193. past_key_values: Cache | None = None,
  194. use_cache: bool | None = False,
  195. **kwargs: Unpack[TransformersKwargs],
  196. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  197. residual = hidden_states
  198. hidden_states = self.input_layernorm(hidden_states)
  199. hidden_states_attention, _ = self.self_attn(
  200. hidden_states=hidden_states,
  201. position_embeddings=position_embeddings,
  202. attention_mask=attention_mask,
  203. past_key_values=past_key_values,
  204. use_cache=use_cache,
  205. **kwargs,
  206. )
  207. hidden_states_mlp = self.mlp(hidden_states)
  208. hidden_states = residual + hidden_states_attention + hidden_states_mlp
  209. return hidden_states
  210. class Cohere2PreTrainedModel(CoherePreTrainedModel):
  211. config: Cohere2Config
  212. _can_record_outputs = {
  213. "hidden_states": Cohere2DecoderLayer,
  214. "attentions": Cohere2Attention,
  215. }
  216. class Cohere2Model(Gemma2Model):
  217. def __init__(self, config: Cohere2Config):
  218. super().__init__(config)
  219. self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
  220. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  221. def forward(
  222. self,
  223. input_ids: torch.LongTensor | None = None,
  224. attention_mask: torch.Tensor | None = None,
  225. position_ids: torch.LongTensor | None = None,
  226. past_key_values: Cache | None = None,
  227. inputs_embeds: torch.FloatTensor | None = None,
  228. use_cache: bool | None = None,
  229. **kwargs: Unpack[TransformersKwargs],
  230. ) -> BaseModelOutputWithPast:
  231. if (input_ids is None) ^ (inputs_embeds is not None):
  232. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  233. if inputs_embeds is None:
  234. inputs_embeds = self.embed_tokens(input_ids)
  235. if use_cache and past_key_values is None:
  236. past_key_values = DynamicCache(config=self.config)
  237. if position_ids is None:
  238. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  239. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  240. position_ids = position_ids.unsqueeze(0)
  241. if not isinstance(causal_mask_mapping := attention_mask, dict):
  242. mask_kwargs = {
  243. "config": self.config,
  244. "inputs_embeds": inputs_embeds,
  245. "attention_mask": attention_mask,
  246. "past_key_values": past_key_values,
  247. "position_ids": position_ids,
  248. }
  249. causal_mask_mapping = {
  250. "full_attention": create_causal_mask(**mask_kwargs),
  251. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  252. }
  253. hidden_states = inputs_embeds
  254. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  255. for i, decoder_layer in enumerate(self.layers):
  256. hidden_states = decoder_layer(
  257. hidden_states,
  258. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  259. position_embeddings=position_embeddings,
  260. past_key_values=past_key_values,
  261. use_cache=use_cache,
  262. position_ids=position_ids,
  263. **kwargs,
  264. )
  265. hidden_states = self.norm(hidden_states)
  266. return BaseModelOutputWithPast(
  267. last_hidden_state=hidden_states,
  268. past_key_values=past_key_values,
  269. )
  270. class Cohere2ForCausalLM(CohereForCausalLM):
  271. pass
# Public names exported by this module.
__all__ = ["Cohere2Config", "Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]