modeling_cwm.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/cwm/modular_cwm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_cwm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...generation import GenerationMixin
  27. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  28. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  32. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  36. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  37. from ...utils.output_capturing import capture_outputs
  38. from .configuration_cwm import CwmConfig
  39. class CwmRotaryEmbedding(nn.Module):
  40. inv_freq: torch.Tensor # fix linting for `register_buffer`
  41. def __init__(self, config: CwmConfig, device=None):
  42. super().__init__()
  43. self.max_seq_len_cached = config.max_position_embeddings
  44. self.original_max_seq_len = config.max_position_embeddings
  45. self.config = config
  46. self.rope_type = self.config.rope_parameters["rope_type"]
  47. rope_init_fn: Callable = self.compute_default_rope_parameters
  48. if self.rope_type != "default":
  49. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  50. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  51. self.register_buffer("inv_freq", inv_freq, persistent=False)
  52. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  53. @staticmethod
  54. def compute_default_rope_parameters(
  55. config: CwmConfig | None = None,
  56. device: Optional["torch.device"] = None,
  57. seq_len: int | None = None,
  58. ) -> tuple["torch.Tensor", float]:
  59. """
  60. Computes the inverse frequencies according to the original RoPE implementation
  61. Args:
  62. config ([`~transformers.PreTrainedConfig`]):
  63. The model configuration.
  64. device (`torch.device`):
  65. The device to use for initialization of the inverse frequencies.
  66. seq_len (`int`, *optional*):
  67. The current sequence length. Unused for this type of RoPE.
  68. Returns:
  69. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  70. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  71. """
  72. base = config.rope_parameters["rope_theta"]
  73. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  74. attention_factor = 1.0 # Unused in this type of RoPE
  75. # Compute the inverse frequencies
  76. inv_freq = 1.0 / (
  77. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  78. )
  79. return inv_freq, attention_factor
  80. @torch.no_grad()
  81. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  82. def forward(self, x, position_ids):
  83. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  84. position_ids_expanded = position_ids[:, None, :].float()
  85. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  86. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  87. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  88. emb = torch.cat((freqs, freqs), dim=-1)
  89. cos = emb.cos() * self.attention_scaling
  90. sin = emb.sin() * self.attention_scaling
  91. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  92. def rotate_half(x):
  93. """Rotates half the hidden dims of the input."""
  94. x1 = x[..., : x.shape[-1] // 2]
  95. x2 = x[..., x.shape[-1] // 2 :]
  96. return torch.cat((-x2, x1), dim=-1)
  97. @use_kernel_func_from_hub("rotary_pos_emb")
  98. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  99. """Applies Rotary Position Embedding to the query and key tensors.
  100. Args:
  101. q (`torch.Tensor`): The query tensor.
  102. k (`torch.Tensor`): The key tensor.
  103. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  104. sin (`torch.Tensor`): The sine part of the rotary embedding.
  105. unsqueeze_dim (`int`, *optional*, defaults to 1):
  106. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  107. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  108. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  109. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  110. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  111. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  112. Returns:
  113. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  114. """
  115. cos = cos.unsqueeze(unsqueeze_dim)
  116. sin = sin.unsqueeze(unsqueeze_dim)
  117. q_embed = (q * cos) + (rotate_half(q) * sin)
  118. k_embed = (k * cos) + (rotate_half(k) * sin)
  119. return q_embed, k_embed
  120. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  121. """
  122. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  123. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  124. """
  125. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  126. if n_rep == 1:
  127. return hidden_states
  128. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  129. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  130. def eager_attention_forward(
  131. module: nn.Module,
  132. query: torch.Tensor,
  133. key: torch.Tensor,
  134. value: torch.Tensor,
  135. attention_mask: torch.Tensor | None,
  136. scaling: float,
  137. dropout: float = 0.0,
  138. **kwargs: Unpack[TransformersKwargs],
  139. ):
  140. key_states = repeat_kv(key, module.num_key_value_groups)
  141. value_states = repeat_kv(value, module.num_key_value_groups)
  142. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  143. if attention_mask is not None:
  144. attn_weights = attn_weights + attention_mask
  145. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  146. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  147. attn_output = torch.matmul(attn_weights, value_states)
  148. attn_output = attn_output.transpose(1, 2).contiguous()
  149. return attn_output, attn_weights
  150. @use_kernelized_func(apply_rotary_pos_emb)
  151. class CwmAttention(nn.Module):
  152. """Multi-headed attention from 'Attention Is All You Need' paper"""
  153. def __init__(self, config: CwmConfig, layer_idx: int):
  154. super().__init__()
  155. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  156. self.config = config
  157. self.layer_idx = layer_idx
  158. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  159. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  160. self.scaling = self.head_dim**-0.5
  161. self.attention_dropout = config.attention_dropout
  162. self.is_causal = True
  163. self.q_proj = torch.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  164. self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  165. self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  166. self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  167. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  168. def forward(
  169. self,
  170. hidden_states: torch.Tensor,
  171. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  172. attention_mask: torch.Tensor | None,
  173. past_key_values: Cache | None = None,
  174. **kwargs: Unpack[FlashAttentionKwargs],
  175. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  176. input_shape = hidden_states.shape[:-1]
  177. hidden_shape = (*input_shape, -1, self.head_dim)
  178. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  179. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  180. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  181. cos, sin = position_embeddings
  182. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  183. if past_key_values is not None:
  184. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  185. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  186. self.config._attn_implementation, eager_attention_forward
  187. )
  188. attn_output, attn_weights = attention_interface(
  189. self,
  190. query_states,
  191. key_states,
  192. value_states,
  193. attention_mask,
  194. dropout=0.0 if not self.training else self.attention_dropout,
  195. scaling=self.scaling,
  196. sliding_window=self.sliding_window, # main diff with Llama
  197. **kwargs,
  198. )
  199. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  200. attn_output = self.o_proj(attn_output)
  201. return attn_output, attn_weights
  202. @use_kernel_forward_from_hub("RMSNorm")
  203. class CwmRMSNorm(nn.Module):
  204. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  205. """
  206. CwmRMSNorm is equivalent to T5LayerNorm
  207. """
  208. super().__init__()
  209. self.weight = nn.Parameter(torch.ones(hidden_size))
  210. self.variance_epsilon = eps
  211. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  212. input_dtype = hidden_states.dtype
  213. hidden_states = hidden_states.to(torch.float32)
  214. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  215. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  216. return self.weight * hidden_states.to(input_dtype)
  217. def extra_repr(self):
  218. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  219. class CwmMLP(nn.Module):
  220. def __init__(self, config):
  221. super().__init__()
  222. self.config = config
  223. self.hidden_size = config.hidden_size
  224. self.intermediate_size = config.intermediate_size
  225. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  226. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  227. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  228. self.act_fn = ACT2FN[config.hidden_act]
  229. def forward(self, x):
  230. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  231. return down_proj
  232. class CwmDecoderLayer(GradientCheckpointingLayer):
  233. def __init__(self, config: CwmConfig, layer_idx: int):
  234. super().__init__()
  235. self.hidden_size = config.hidden_size
  236. self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)
  237. self.mlp = CwmMLP(config)
  238. self.input_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  239. self.post_attention_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  240. def forward(
  241. self,
  242. hidden_states: torch.Tensor,
  243. attention_mask: torch.Tensor | None = None,
  244. position_ids: torch.LongTensor | None = None,
  245. past_key_values: Cache | None = None,
  246. use_cache: bool | None = False,
  247. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  248. **kwargs: Unpack[TransformersKwargs],
  249. ) -> torch.Tensor:
  250. residual = hidden_states
  251. hidden_states = self.input_layernorm(hidden_states)
  252. # Self Attention
  253. hidden_states, _ = self.self_attn(
  254. hidden_states=hidden_states,
  255. attention_mask=attention_mask,
  256. position_ids=position_ids,
  257. past_key_values=past_key_values,
  258. use_cache=use_cache,
  259. position_embeddings=position_embeddings,
  260. **kwargs,
  261. )
  262. hidden_states = residual + hidden_states
  263. # Fully Connected
  264. residual = hidden_states
  265. hidden_states = self.post_attention_layernorm(hidden_states)
  266. hidden_states = self.mlp(hidden_states)
  267. hidden_states = residual + hidden_states
  268. return hidden_states
  269. @auto_docstring
  270. class CwmPreTrainedModel(PreTrainedModel):
  271. config: CwmConfig
  272. base_model_prefix = "model"
  273. supports_gradient_checkpointing = True
  274. _no_split_modules = ["CwmDecoderLayer"]
  275. _skip_keys_device_placement = ["past_key_values"]
  276. _supports_flash_attn = True
  277. _supports_sdpa = True
  278. _supports_flex_attn = True
  279. _can_compile_fullgraph = True
  280. _supports_attention_backend = True
  281. _can_record_outputs = {
  282. "hidden_states": CwmDecoderLayer,
  283. "attentions": CwmAttention,
  284. }
  285. class CwmModelOutputWithPast(BaseModelOutputWithPast):
  286. pass
  287. @auto_docstring
  288. class CwmModel(CwmPreTrainedModel):
  289. config_class = CwmConfig
  290. def __init__(self, config: CwmConfig):
  291. super().__init__(config)
  292. self.padding_idx = config.pad_token_id
  293. self.vocab_size = config.vocab_size
  294. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  295. self.layers = torch.nn.ModuleList(
  296. [CwmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  297. )
  298. self.norm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  299. self.rotary_emb = CwmRotaryEmbedding(config=config)
  300. self.gradient_checkpointing = False
  301. # Initialize weights and apply final processing
  302. self.post_init()
  303. @merge_with_config_defaults
  304. @capture_outputs
  305. @auto_docstring
  306. def forward(
  307. self,
  308. input_ids: torch.LongTensor | None = None,
  309. attention_mask: torch.Tensor | None = None,
  310. position_ids: torch.LongTensor | None = None,
  311. past_key_values: Cache | None = None,
  312. inputs_embeds: torch.FloatTensor | None = None,
  313. use_cache: bool | None = None,
  314. **kwargs: Unpack[TransformersKwargs],
  315. ) -> CwmModelOutputWithPast:
  316. if (input_ids is None) ^ (inputs_embeds is not None):
  317. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  318. if inputs_embeds is None:
  319. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  320. if use_cache and past_key_values is None:
  321. past_key_values = DynamicCache(config=self.config)
  322. if position_ids is None:
  323. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  324. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  325. position_ids = position_ids.unsqueeze(0)
  326. if not isinstance(causal_mask_mapping := attention_mask, dict):
  327. mask_kwargs = {
  328. "config": self.config,
  329. "inputs_embeds": inputs_embeds,
  330. "attention_mask": attention_mask,
  331. "past_key_values": past_key_values,
  332. "position_ids": position_ids,
  333. }
  334. sliding_mask_kwargs = mask_kwargs.copy()
  335. causal_mask_mapping = {
  336. "full_attention": create_causal_mask(**mask_kwargs),
  337. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  338. }
  339. hidden_states = inputs_embeds
  340. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  341. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  342. hidden_states = decoder_layer(
  343. hidden_states,
  344. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  345. position_ids=position_ids,
  346. past_key_values=past_key_values,
  347. position_embeddings=position_embeddings,
  348. **kwargs,
  349. )
  350. hidden_states = self.norm(hidden_states)
  351. return CwmModelOutputWithPast(
  352. last_hidden_state=hidden_states,
  353. past_key_values=past_key_values,
  354. )
  355. @auto_docstring
  356. class CwmForCausalLM(CwmPreTrainedModel, GenerationMixin):
  357. _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
  358. _tp_plan = {"lm_head": "colwise_gather_output"}
  359. _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
  360. def __init__(self, config):
  361. super().__init__(config)
  362. self.model = CwmModel(config)
  363. self.vocab_size = config.vocab_size
  364. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  365. # Initialize weights and apply final processing
  366. self.post_init()
  367. @can_return_tuple
  368. @auto_docstring
  369. def forward(
  370. self,
  371. input_ids: torch.LongTensor | None = None,
  372. attention_mask: torch.Tensor | None = None,
  373. position_ids: torch.LongTensor | None = None,
  374. past_key_values: Cache | None = None,
  375. inputs_embeds: torch.FloatTensor | None = None,
  376. labels: torch.LongTensor | None = None,
  377. use_cache: bool | None = None,
  378. logits_to_keep: int | torch.Tensor = 0,
  379. **kwargs: Unpack[TransformersKwargs],
  380. ) -> CausalLMOutputWithPast:
  381. r"""
  382. Example:
  383. ```python
  384. >>> from transformers import AutoTokenizer, CwmForCausalLM
  385. >>> model = CwmForCausalLM.from_pretrained("meta-cwm/Cwm-2-7b-hf")
  386. >>> tokenizer = AutoTokenizer.from_pretrained("meta-cwm/Cwm-2-7b-hf")
  387. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  388. >>> inputs = tokenizer(prompt, return_tensors="pt")
  389. >>> # Generate
  390. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  391. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  392. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  393. ```"""
  394. outputs: BaseModelOutputWithPast = self.model(
  395. input_ids=input_ids,
  396. attention_mask=attention_mask,
  397. position_ids=position_ids,
  398. past_key_values=past_key_values,
  399. inputs_embeds=inputs_embeds,
  400. use_cache=use_cache,
  401. **kwargs,
  402. )
  403. hidden_states = outputs.last_hidden_state
  404. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  405. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  406. logits = self.lm_head(hidden_states[:, slice_indices, :])
  407. loss = None
  408. if labels is not None:
  409. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  410. return CausalLMOutputWithPast(
  411. loss=loss,
  412. logits=logits,
  413. past_key_values=outputs.past_key_values,
  414. hidden_states=outputs.hidden_states,
  415. attentions=outputs.attentions,
  416. )
  417. __all__ = ["CwmPreTrainedModel", "CwmModel", "CwmForCausalLM"]