| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- # Copyright 2025
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
- from huggingface_hub.dataclasses import strict
- from ...cache_utils import Cache, DynamicCache
- from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_outputs import BaseModelOutputWithPast
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, logging
- from ..llama.configuration_llama import LlamaConfig
- from ..llama.modeling_llama import (
- LlamaDecoderLayer,
- LlamaForCausalLM,
- LlamaModel,
- LlamaPreTrainedModel,
- )
- from ..qwen2.modeling_qwen2 import Qwen2Attention, Qwen2RotaryEmbedding
# Module-level logger obtained from the transformers logging utilities (not referenced elsewhere in this section).
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="facebook/cwm")
@strict
class CwmConfig(LlamaConfig):
    r"""
    sliding_window (`int`, *optional*, defaults to 8192):
        Size of the attention window used by layers whose type is `"sliding_attention"`. A falsy value
        (`None` or `0`) is normalized to `None`.
    layer_types (`list[str]`, *optional*):
        Attention type of each decoder layer. Every entry must be `"full_attention"` or `"sliding_attention"`, and
        there must be at least `num_hidden_layers` entries. If not given, every 4th layer (starting at layer 0) uses
        full attention and all other layers use sliding-window attention.
    """

    model_type = "cwm"
    default_theta = 1_000_000.0

    vocab_size: int = 128256
    hidden_size: int = 6144
    intermediate_size: int = 21504
    num_hidden_layers: int = 64
    num_attention_heads: int = 48
    num_key_value_heads: int = 8
    head_dim: int = 128
    hidden_act: str = "silu"
    max_position_embeddings: int = 131072
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-5
    use_cache: bool = True
    pad_token_id: int | None = None
    eos_token_id: int | list[int] | None = None
    bos_token_id: int = 128000
    tie_word_embeddings: bool = False
    attention_dropout: float | int = 0.0
    pretraining_tp: int = 1
    mlp_bias: bool = False
    rope_parameters: dict | None = None
    # `None` is a legal value: __post_init__ normalizes falsy windows to None, so the annotation must allow it
    # (a plain `int` annotation makes the strict validator reject the normalized value).
    sliding_window: int | None = 8192
    layer_types: list[str] | None = None  # ["full_attention"|"sliding_attention"] per layer
    # CwmAttention always builds bias-free q/k/v projections, so the inherited `attention_bias` option is dropped.
    # NOTE(review): presumably the modular-converter convention for removing an inherited field — confirm.
    attention_bias = AttributeError()

    def __post_init__(self, **kwargs):
        """Fill CWM-specific defaults, normalize and validate attention settings, then defer to `LlamaConfig`.

        Raises:
            ValueError: if `layer_types` contains an unknown attention type or has fewer entries than
                `num_hidden_layers` (either would otherwise only fail later, inside the model's forward pass).
        """
        if self.rope_parameters is None:
            # llama3-style RoPE scaling: factor 16 over an original 8192-token context
            # (16 * 8192 == 131072, the default max_position_embeddings).
            self.rope_parameters = {
                "rope_theta": 1_000_000.0,
                "factor": 16.0,
                "high_freq_factor": 4.0,
                "low_freq_factor": 1.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            }

        if self.layer_types is None:
            # Default pattern: every 4th layer uses full attention, others use sliding attention
            window_pattern = 4
            self.layer_types = [
                ("full_attention" if (i % window_pattern == 0) else "sliding_attention")
                for i in range(self.num_hidden_layers)
            ]

        self.sliding_window = int(self.sliding_window) if self.sliding_window else None
        self.layer_types = list(self.layer_types)

        # CwmModel looks up one mask per layer by these names and indexes layer_types for every layer,
        # so reject configs that could only crash there.
        allowed_layer_types = {"full_attention", "sliding_attention"}
        unknown_layer_types = sorted(set(self.layer_types) - allowed_layer_types)
        if unknown_layer_types:
            raise ValueError(
                f"`layer_types` entries must be 'full_attention' or 'sliding_attention', got {unknown_layer_types}"
            )
        if len(self.layer_types) < self.num_hidden_layers:
            raise ValueError(
                f"`layer_types` has {len(self.layer_types)} entries but `num_hidden_layers` is "
                f"{self.num_hidden_layers}; every decoder layer needs an attention type."
            )

        self.eos_token_id = self.eos_token_id if self.eos_token_id is not None else [128001, 128008, 128009]
        super().__post_init__(**kwargs)
class CwmRotaryEmbedding(Qwen2RotaryEmbedding):
    """Rotary position embedding for CWM; reuses `Qwen2RotaryEmbedding` without modification."""
class CwmAttention(Qwen2Attention):
    """Qwen2-style attention whose query, key and value projections have no bias term."""

    def __init__(self, config: CwmConfig, layer_idx: int):
        super().__init__(config=config, layer_idx=layer_idx)
        # Overwrite the q/k/v projections the parent created with bias-free ones.
        # NOTE(review): presumably Qwen2Attention builds them with bias=True — confirm.
        model_width = config.hidden_size
        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = torch.nn.Linear(model_width, query_width, bias=False)
        self.k_proj = torch.nn.Linear(model_width, kv_width, bias=False)
        self.v_proj = torch.nn.Linear(model_width, kv_width, bias=False)
class CwmDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose self-attention module is swapped for `CwmAttention`."""

    def __init__(self, config: CwmConfig, layer_idx: int):
        super().__init__(config=config, layer_idx=layer_idx)
        # Replace the parent's attention module; everything else in the layer is inherited as-is.
        cwm_attention = CwmAttention(config=config, layer_idx=layer_idx)
        self.self_attn = cwm_attention
class CwmPreTrainedModel(LlamaPreTrainedModel):
    """Common base class for CWM models; inherits everything from `LlamaPreTrainedModel` unchanged."""
class CwmModelOutputWithPast(BaseModelOutputWithPast):
    """Output container returned by `CwmModel.forward`; identical to `BaseModelOutputWithPast`."""
class CwmModel(LlamaModel):
    """CWM decoder backbone that mixes full-attention and sliding-window-attention layers.

    Relies on the `embed_tokens`, `rotary_emb` and `norm` modules created by `LlamaModel.__init__`, replaces the
    decoder stack with `CwmDecoderLayer` instances, and feeds each layer the attention mask that matches its entry
    in `config.layer_types`.
    """

    config_class = CwmConfig

    def __init__(self, config: CwmConfig):
        super().__init__(config)
        # Rebuild the decoder stack with CWM layers (bias-free attention projections).
        self.layers = torch.nn.ModuleList(
            [CwmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CwmModelOutputWithPast:
        """Run the embedding, decoder stack and final norm.

        Args:
            input_ids: Token ids to embed. Exactly one of `input_ids` / `inputs_embeds` must be given.
            attention_mask: Either a raw mask, from which full and sliding-window causal masks are built, or an
                already-built dict mapping each layer type ("full_attention" / "sliding_attention") to its mask.
            position_ids: Positions for rotary embeddings. Defaults to `0..seq_len-1` offset by the number of
                tokens already in `past_key_values`, with a leading batch dimension of 1.
            past_key_values: Key/value cache. A fresh `DynamicCache` is created only when `use_cache` is truthy
                and no cache is passed.
            inputs_embeds: Precomputed input embeddings; assumed (batch, seq_len, hidden) since dim 1 is used as
                the sequence length — TODO confirm.
            use_cache: Whether to create a cache when none is supplied (`None` is treated as false here).
            **kwargs: Forwarded unchanged to every decoder layer.

        Returns:
            `CwmModelOutputWithPast` holding the normed last hidden state and the (possibly new) cache.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # Continue numbering after any tokens already stored in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        # A dict mask is taken as already prepared per layer type; otherwise build both mask kinds once.
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Both builders receive the kwargs unpacked, so the same dict can be shared without copying.
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                position_ids=position_ids,
                past_key_values=past_key_values,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        return CwmModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
class CwmForCausalLM(LlamaForCausalLM):
    """Causal language-modeling model for CWM; reuses `LlamaForCausalLM` without modification."""
# Public names exported by this module (order preserved).
__all__ = ["CwmConfig", "CwmPreTrainedModel", "CwmModel", "CwmForCausalLM"]
|