modular_lfm2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. import torch
  16. import torch.nn.functional as F
  17. from torch import nn
  18. from ...cache_utils import Cache, DynamicCache
  19. from ...masking_utils import create_causal_mask
  20. from ...modeling_layers import GradientCheckpointingLayer
  21. from ...modeling_outputs import BaseModelOutputWithPast
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  23. from ...processing_utils import Unpack
  24. from ...utils import TransformersKwargs, logging
  25. from ...utils.import_utils import is_causal_conv1d_available, is_torchdynamo_compiling
  26. from ..bamba.modeling_bamba import apply_mask_to_padding_states
  27. from ..gemma2.modeling_gemma2 import Gemma2RotaryEmbedding
  28. from ..llama.modeling_llama import (
  29. LlamaAttention,
  30. LlamaForCausalLM,
  31. LlamaModel,
  32. LlamaPreTrainedModel,
  33. LlamaRMSNorm,
  34. apply_rotary_pos_emb,
  35. eager_attention_forward,
  36. )
  37. from .configuration_lfm2 import Lfm2Config
  38. if is_causal_conv1d_available():
  39. from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
  40. else:
  41. causal_conv1d_fn, causal_conv1d_update = None, None
  42. kernel_modules = (causal_conv1d_fn, causal_conv1d_update)
  43. is_fast_path_available = all(kernel_modules)
  44. logger = logging.get_logger(__name__)
  45. class Lfm2RMSNorm(LlamaRMSNorm):
  46. pass
  47. class Lfm2RotaryEmbedding(Gemma2RotaryEmbedding):
  48. pass
  49. class Lfm2MLP(nn.Module):
  50. def __init__(self, config: Lfm2Config):
  51. super().__init__()
  52. intermediate_size = config.intermediate_size
  53. if config.block_auto_adjust_ff_dim:
  54. intermediate_size = int(2 * intermediate_size / 3)
  55. # custom dim factor multiplier
  56. if config.block_ffn_dim_multiplier is not None:
  57. intermediate_size = int(config.block_ffn_dim_multiplier * intermediate_size)
  58. intermediate_size = config.block_multiple_of * (
  59. (intermediate_size + config.block_multiple_of - 1) // config.block_multiple_of
  60. )
  61. self.w1 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  62. self.w3 = nn.Linear(config.hidden_size, intermediate_size, bias=False)
  63. self.w2 = nn.Linear(intermediate_size, config.hidden_size, bias=False)
  64. def forward(self, x):
  65. return self.w2(F.silu(self.w1(x)) * self.w3(x))
  66. class Lfm2Attention(LlamaAttention):
  67. def __init__(self, config: Lfm2Config, layer_idx: int):
  68. super().__init__(config, layer_idx)
  69. self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  70. self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  71. self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
  72. self.out_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
  73. self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  74. self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps)
  75. del self.o_proj
  76. del self.attention_dropout
  77. def forward(
  78. self,
  79. hidden_states: torch.Tensor,
  80. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  81. attention_mask: torch.Tensor | None,
  82. past_key_values: Cache | None = None,
  83. **kwargs,
  84. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  85. input_shape = hidden_states.shape[:-1]
  86. hidden_shape = (*input_shape, -1, self.head_dim)
  87. query_states = self.q_layernorm(self.q_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  88. key_states = self.k_layernorm(self.k_proj(hidden_states).view(*hidden_shape)).transpose(1, 2)
  89. value_states = self.v_proj(hidden_states).view(*hidden_shape).transpose(1, 2)
  90. cos, sin = position_embeddings
  91. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  92. if past_key_values is not None:
  93. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  94. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  95. self.config._attn_implementation, eager_attention_forward
  96. )
  97. attn_output, attn_weights = attention_interface(
  98. self,
  99. query_states,
  100. key_states,
  101. value_states,
  102. attention_mask,
  103. dropout=0.0,
  104. scaling=self.scaling,
  105. **kwargs,
  106. )
  107. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  108. output = self.out_proj(attn_output)
  109. return output, attn_weights
  110. class Lfm2ShortConv(nn.Module):
  111. def __init__(
  112. self,
  113. config: Lfm2Config,
  114. layer_idx: int,
  115. ):
  116. super().__init__()
  117. self.config = config
  118. self.layer_idx = layer_idx
  119. self.L_cache = config.conv_L_cache
  120. self.bias = config.conv_bias
  121. self.conv = nn.Conv1d(
  122. in_channels=config.hidden_size,
  123. out_channels=config.hidden_size,
  124. kernel_size=self.L_cache,
  125. groups=config.hidden_size,
  126. bias=self.bias,
  127. padding=self.L_cache - 1,
  128. )
  129. self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias)
  130. self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias)
  131. def cuda_kernels_forward(
  132. self,
  133. x: torch.Tensor,
  134. past_key_values: Cache | None = None,
  135. attention_mask: torch.Tensor | None = None,
  136. ):
  137. x = apply_mask_to_padding_states(x, attention_mask)
  138. BCx = self.in_proj(x).transpose(-1, -2)
  139. B, C, x = BCx.chunk(3, dim=-2)
  140. Bx = B * x
  141. conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
  142. if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx):
  143. conv_out = causal_conv1d_update(
  144. Bx.squeeze(-1),
  145. past_key_values.layers[self.layer_idx].conv_states,
  146. conv_weights,
  147. self.conv.bias,
  148. None,
  149. )
  150. conv_out = conv_out.unsqueeze(-1)
  151. else:
  152. if past_key_values is not None:
  153. conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
  154. conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx)
  155. conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None)
  156. y = C * conv_out
  157. y = self.out_proj(y.transpose(-1, -2).contiguous())
  158. return y
  159. def slow_forward(
  160. self,
  161. x: torch.Tensor,
  162. past_key_values: Cache | None = None,
  163. attention_mask: torch.Tensor | None = None,
  164. ):
  165. seqlen = x.shape[1]
  166. x = apply_mask_to_padding_states(x, attention_mask)
  167. BCx = self.in_proj(x).transpose(-1, -2)
  168. B, C, x = BCx.chunk(3, dim=-2)
  169. Bx = B * x
  170. if past_key_values is not None and past_key_values.has_previous_state(self.layer_idx):
  171. conv_state = past_key_values.update_conv_state(Bx, self.layer_idx)
  172. conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1)
  173. if self.bias:
  174. conv_out += self.conv.bias
  175. conv_out = conv_out.unsqueeze(-1)
  176. else:
  177. if past_key_values is not None:
  178. conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0))
  179. conv_state = past_key_values.update_conv_state(conv_state, self.layer_idx)
  180. conv_out = self.conv(Bx)[..., :seqlen]
  181. y = C * conv_out
  182. y = y.transpose(-1, -2).contiguous()
  183. y = self.out_proj(y)
  184. return y
  185. def forward(
  186. self,
  187. hidden_states: torch.Tensor,
  188. past_key_values: Cache | None = None,
  189. attention_mask: torch.Tensor | None = None,
  190. ):
  191. if is_fast_path_available and "cuda" in hidden_states.device.type and not is_torchdynamo_compiling():
  192. return self.cuda_kernels_forward(hidden_states, past_key_values, attention_mask)
  193. return self.slow_forward(hidden_states, past_key_values, attention_mask)
  194. class Lfm2DecoderLayer(GradientCheckpointingLayer):
  195. def __init__(self, config: Lfm2Config, layer_idx: int):
  196. super().__init__()
  197. self.is_attention_layer = config.layer_types[layer_idx] == "full_attention"
  198. if self.is_attention_layer:
  199. self.self_attn = Lfm2Attention(config, layer_idx)
  200. else:
  201. self.conv = Lfm2ShortConv(config, layer_idx)
  202. self.feed_forward = Lfm2MLP(config)
  203. self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  204. self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  205. def forward(
  206. self,
  207. hidden_states: torch.Tensor,
  208. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  209. attention_mask: torch.Tensor | None = None,
  210. position_ids: torch.LongTensor | None = None,
  211. past_key_values: Cache | None = None,
  212. **kwargs,
  213. ) -> torch.Tensor:
  214. residual = hidden_states
  215. if self.is_attention_layer:
  216. hidden_states, _ = self.self_attn(
  217. hidden_states=self.operator_norm(hidden_states),
  218. position_embeddings=position_embeddings,
  219. attention_mask=attention_mask,
  220. position_ids=position_ids,
  221. past_key_values=past_key_values,
  222. **kwargs,
  223. )
  224. else:
  225. hidden_states = self.conv(
  226. hidden_states=self.operator_norm(hidden_states),
  227. past_key_values=past_key_values,
  228. attention_mask=attention_mask,
  229. )
  230. hidden_states = hidden_states + residual
  231. hidden_states = hidden_states + self.feed_forward(self.ffn_norm(hidden_states))
  232. return hidden_states
  233. class Lfm2PreTrainedModel(LlamaPreTrainedModel):
  234. _can_compile_fullgraph = False
  235. class Lfm2Model(LlamaModel):
  236. def __init__(self, config: Lfm2Config):
  237. super().__init__(config)
  238. self.embedding_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps)
  239. del self.norm
  240. def forward(
  241. self,
  242. input_ids: torch.LongTensor | None = None,
  243. attention_mask: torch.Tensor | None = None,
  244. position_ids: torch.LongTensor | None = None,
  245. past_key_values: Cache | None = None,
  246. inputs_embeds: torch.FloatTensor | None = None,
  247. use_cache: bool | None = None,
  248. **kwargs: Unpack[TransformersKwargs],
  249. ) -> BaseModelOutputWithPast:
  250. if (input_ids is None) ^ (inputs_embeds is not None):
  251. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  252. if inputs_embeds is None:
  253. inputs_embeds = self.embed_tokens(input_ids)
  254. if use_cache and past_key_values is None:
  255. past_key_values = DynamicCache(config=self.config)
  256. if position_ids is None:
  257. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  258. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  259. position_ids = position_ids.unsqueeze(0)
  260. causal_mask = create_causal_mask(
  261. config=self.config,
  262. inputs_embeds=inputs_embeds,
  263. attention_mask=attention_mask,
  264. past_key_values=past_key_values,
  265. position_ids=position_ids,
  266. )
  267. # Skip masking for decoding stage. We check shape here to be compile-friendly
  268. linear_attention = attention_mask if inputs_embeds.shape[1] != 1 else None
  269. hidden_states = inputs_embeds
  270. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  271. # decoder layers
  272. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  273. layer_mask = causal_mask if self.config.layer_types[i] == "full_attention" else linear_attention
  274. hidden_states = decoder_layer(
  275. hidden_states,
  276. attention_mask=layer_mask,
  277. position_embeddings=position_embeddings,
  278. position_ids=position_ids,
  279. past_key_values=past_key_values,
  280. **kwargs,
  281. )
  282. hidden_states = self.embedding_norm(hidden_states)
  283. return BaseModelOutputWithPast(
  284. last_hidden_state=hidden_states,
  285. past_key_values=past_key_values,
  286. )
  287. class Lfm2ForCausalLM(LlamaForCausalLM):
  288. pass
  289. __all__ = ["Lfm2ForCausalLM", "Lfm2Model", "Lfm2PreTrainedModel"]