modular_cwm.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # Copyright 2025
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from huggingface_hub.dataclasses import strict
  16. from ...cache_utils import Cache, DynamicCache
  17. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  18. from ...modeling_outputs import BaseModelOutputWithPast
  19. from ...processing_utils import Unpack
  20. from ...utils import TransformersKwargs, auto_docstring, logging
  21. from ..llama.configuration_llama import LlamaConfig
  22. from ..llama.modeling_llama import (
  23. LlamaDecoderLayer,
  24. LlamaForCausalLM,
  25. LlamaModel,
  26. LlamaPreTrainedModel,
  27. )
  28. from ..qwen2.modeling_qwen2 import Qwen2Attention, Qwen2RotaryEmbedding
# Module-level logger from the package's `logging` utility (imported from ...utils);
# not referenced by any code visible in this module.
logger = logging.get_logger(__name__)
  30. @auto_docstring(checkpoint="facebook/cwm")
  31. @strict
  32. class CwmConfig(LlamaConfig):
  33. model_type = "cwm"
  34. default_theta = 1_000_000.0
  35. vocab_size: int = 128256
  36. hidden_size: int = 6144
  37. intermediate_size: int = 21504
  38. num_hidden_layers: int = 64
  39. num_attention_heads: int = 48
  40. num_key_value_heads: int = 8
  41. head_dim: int = 128
  42. hidden_act: str = "silu"
  43. max_position_embeddings: int = 131072
  44. initializer_range: float = 0.02
  45. rms_norm_eps: float = 1e-5
  46. use_cache: bool = True
  47. pad_token_id: int | None = None
  48. eos_token_id: int | list[int] | None = None
  49. bos_token_id: int = 128000
  50. tie_word_embeddings: bool = False
  51. attention_dropout: float | int = 0.0
  52. pretraining_tp: int = 1
  53. mlp_bias: bool = False
  54. rope_parameters: dict | None = None
  55. sliding_window: int = 8192
  56. layer_types: list[str] | None = None # ["full_attention"|"sliding_attention"] per layer
  57. attention_bias = AttributeError()
  58. def __post_init__(self, **kwargs):
  59. if self.rope_parameters is None:
  60. self.rope_parameters = {
  61. "rope_theta": 1_000_000.0,
  62. "factor": 16.0,
  63. "high_freq_factor": 4.0,
  64. "low_freq_factor": 1.0,
  65. "original_max_position_embeddings": 8192,
  66. "rope_type": "llama3",
  67. }
  68. if self.layer_types is None:
  69. # Default pattern: every 4th layer uses full attention, others use sliding attention
  70. window_pattern = 4
  71. self.layer_types = [
  72. ("full_attention" if (i % window_pattern == 0) else "sliding_attention")
  73. for i in range(self.num_hidden_layers)
  74. ]
  75. self.sliding_window = int(self.sliding_window) if self.sliding_window else None
  76. self.layer_types = list(self.layer_types)
  77. self.eos_token_id = self.eos_token_id if self.eos_token_id is not None else [128001, 128008, 128009]
  78. super().__post_init__(**kwargs)
  79. class CwmRotaryEmbedding(Qwen2RotaryEmbedding):
  80. pass
  81. class CwmAttention(Qwen2Attention):
  82. def __init__(self, config: CwmConfig, layer_idx: int):
  83. super().__init__(config=config, layer_idx=layer_idx)
  84. self.q_proj = torch.nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  85. self.k_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  86. self.v_proj = torch.nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  87. class CwmDecoderLayer(LlamaDecoderLayer):
  88. def __init__(self, config: CwmConfig, layer_idx: int):
  89. super().__init__(config=config, layer_idx=layer_idx)
  90. self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)
  91. class CwmPreTrainedModel(LlamaPreTrainedModel):
  92. pass
  93. class CwmModelOutputWithPast(BaseModelOutputWithPast):
  94. pass
  95. class CwmModel(LlamaModel):
  96. config_class = CwmConfig
  97. def __init__(self, config: CwmConfig):
  98. super().__init__(config)
  99. self.layers = torch.nn.ModuleList(
  100. [CwmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  101. )
  102. def forward(
  103. self,
  104. input_ids: torch.LongTensor | None = None,
  105. attention_mask: torch.Tensor | None = None,
  106. position_ids: torch.LongTensor | None = None,
  107. past_key_values: Cache | None = None,
  108. inputs_embeds: torch.FloatTensor | None = None,
  109. use_cache: bool | None = None,
  110. **kwargs: Unpack[TransformersKwargs],
  111. ) -> CwmModelOutputWithPast:
  112. if (input_ids is None) ^ (inputs_embeds is not None):
  113. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  114. if inputs_embeds is None:
  115. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  116. if use_cache and past_key_values is None:
  117. past_key_values = DynamicCache(config=self.config)
  118. if position_ids is None:
  119. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  120. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  121. position_ids = position_ids.unsqueeze(0)
  122. if not isinstance(causal_mask_mapping := attention_mask, dict):
  123. mask_kwargs = {
  124. "config": self.config,
  125. "inputs_embeds": inputs_embeds,
  126. "attention_mask": attention_mask,
  127. "past_key_values": past_key_values,
  128. "position_ids": position_ids,
  129. }
  130. sliding_mask_kwargs = mask_kwargs.copy()
  131. causal_mask_mapping = {
  132. "full_attention": create_causal_mask(**mask_kwargs),
  133. "sliding_attention": create_sliding_window_causal_mask(**sliding_mask_kwargs),
  134. }
  135. hidden_states = inputs_embeds
  136. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  137. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  138. hidden_states = decoder_layer(
  139. hidden_states,
  140. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  141. position_ids=position_ids,
  142. past_key_values=past_key_values,
  143. position_embeddings=position_embeddings,
  144. **kwargs,
  145. )
  146. hidden_states = self.norm(hidden_states)
  147. return CwmModelOutputWithPast(
  148. last_hidden_state=hidden_states,
  149. past_key_values=past_key_values,
  150. )
  151. class CwmForCausalLM(LlamaForCausalLM):
  152. pass
  153. __all__ = [
  154. "CwmConfig",
  155. "CwmPreTrainedModel",
  156. "CwmModel",
  157. "CwmForCausalLM",
  158. ]