modular_diffllama.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # Copyright 2024 weak-kajuma and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # This code is based on Llama implementations in this library and Microsoft's
  4. # Differential Transformer implementations.
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import math
  17. import torch
  18. from torch import nn
  19. from ... import initialization as init
  20. from ...cache_utils import Cache, StaticCache
  21. from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
  22. from ...modeling_utils import PreTrainedModel
  23. from ...utils import logging
  24. from ..gemma.modeling_gemma import GemmaForCausalLM
  25. from ..llama.modeling_llama import (
  26. LlamaDecoderLayer,
  27. LlamaForQuestionAnswering,
  28. LlamaForSequenceClassification,
  29. LlamaForTokenClassification,
  30. LlamaModel,
  31. LlamaPreTrainedModel,
  32. LlamaRotaryEmbedding,
  33. apply_rotary_pos_emb,
  34. repeat_kv,
  35. )
  36. from ..mistral.modeling_mistral import MistralMLP
  37. from .configuration_diffllama import DiffLlamaConfig
# Module-level logger, keyed by this module's import name.
logger = logging.get_logger(__name__)

# NOTE(review): the names suggest these feed docstring tooling, but nothing in this file reads
# them -- confirm before relying on or removing them.
_CHECKPOINT_FOR_DOC = "kajuma/DiffLlama-0.3B-handcut"
_CONFIG_FOR_DOC = "DiffLlamaConfig"
class DiffLlamaMLP(MistralMLP):
    # No overrides: behaves exactly like MistralMLP under a DiffLlama-specific name.
    pass
  43. def lambda_init_fn(layer_idx):
  44. return 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    # No overrides: behaves exactly like LlamaRotaryEmbedding under a DiffLlama-specific name.
    pass
  47. class DiffLlamaAttention(nn.Module):
  48. """Multi-headed attention from 'Attention Is All You Need' paper"""
  49. def __init__(self, config: DiffLlamaConfig, layer_idx: int | None = None):
  50. super().__init__()
  51. self.config = config
  52. self.layer_idx = layer_idx
  53. if layer_idx is None:
  54. logger.warning_once(
  55. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  56. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  57. "when creating this class."
  58. )
  59. self.attention_dropout = config.attention_dropout
  60. self.hidden_size = config.hidden_size
  61. self.num_heads = config.num_attention_heads
  62. self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
  63. self.num_key_value_heads = config.num_key_value_heads
  64. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  65. # under this are not used
  66. self.max_position_embeddings = config.max_position_embeddings
  67. self.is_causal = True
  68. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  69. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  70. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  71. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
  72. self.lambda_init = lambda_init_fn(layer_idx)
  73. self.lambda_q1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  74. self.lambda_k1 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  75. self.lambda_q2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  76. self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,)))
  77. self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False)
  78. def forward(
  79. self,
  80. hidden_states: torch.Tensor,
  81. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  82. attention_mask: torch.Tensor | None = None,
  83. position_ids: torch.LongTensor | None = None,
  84. past_key_values: Cache | None = None,
  85. use_cache: bool = False,
  86. **kwargs,
  87. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  88. bsz, target_len, _ = hidden_states.size()
  89. q_len = target_len
  90. query_states = self.q_proj(hidden_states)
  91. key_states = self.k_proj(hidden_states)
  92. value_states = self.v_proj(hidden_states)
  93. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  94. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  95. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  96. cos, sin = position_embeddings
  97. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  98. if past_key_values is not None:
  99. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  100. key_states = repeat_kv(key_states, self.num_key_value_groups)
  101. value_states = repeat_kv(value_states, self.num_key_value_groups)
  102. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  103. value_states = value_states.repeat(1, 2, 1, 1)
  104. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  105. if attention_mask is not None:
  106. attn_weights = attn_weights + attention_mask
  107. # upcast attention to fp32
  108. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  109. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  110. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  111. query_states.dtype
  112. )
  113. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  114. query_states.dtype
  115. )
  116. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  117. attn_output = torch.matmul(attn_weights, value_states)
  118. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  119. attn_output = attn_output1 - lambda_full * attn_output2
  120. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  121. attn_output = attn_output.transpose(1, 2).contiguous()
  122. attn_output = attn_output.reshape(bsz, q_len, -1)
  123. attn_output = self.o_proj(attn_output)
  124. return attn_output, attn_weights
  125. class DiffLlamaFlashAttention2(DiffLlamaAttention):
  126. """
  127. DiffLlama flash attention module. This module inherits from `DiffLlamaAttention` as the weights of the module stays
  128. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  129. flash attention and deal with padding tokens in case the input contains any of them.
  130. """
  131. def __init__(self, *args, **kwargs):
  132. super().__init__(*args, **kwargs)
  133. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  134. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  135. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  136. self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
  137. def forward(
  138. self,
  139. hidden_states: torch.Tensor,
  140. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  141. attention_mask: torch.LongTensor | None = None,
  142. position_ids: torch.LongTensor | None = None,
  143. past_key_values: Cache | None = None,
  144. use_cache: bool = False,
  145. ) -> tuple[torch.Tensor, None]:
  146. if isinstance(past_key_values, StaticCache):
  147. raise ValueError(
  148. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  149. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  150. )
  151. bsz, q_len, _ = hidden_states.size()
  152. query_states = self.q_proj(hidden_states)
  153. key_states = self.k_proj(hidden_states)
  154. value_states = self.v_proj(hidden_states)
  155. # Flash attention requires the input to have the shape
  156. # batch_size x seq_length x head_dim x hidden_dim
  157. # therefore we just need to keep the original shape
  158. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  159. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  160. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  161. cos, sin = position_embeddings
  162. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  163. if past_key_values is not None:
  164. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  165. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  166. # to be able to avoid many of these transpose/reshape/view.
  167. query_states = query_states.transpose(1, 2)
  168. key_states = key_states.transpose(1, 2)
  169. value_states = value_states.transpose(1, 2)
  170. dropout_rate = self.attention_dropout if self.training else 0.0
  171. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  172. # therefore the input hidden states gets silently casted in float32. Hence, we need
  173. # cast them back in the correct dtype just to be sure everything works as expected.
  174. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  175. # in fp32. (DiffLlamaRMSNorm handles it correctly)
  176. input_dtype = query_states.dtype
  177. device_type = query_states.device.type if query_states.device.type != "mps" else "cpu"
  178. if input_dtype == torch.float32:
  179. if torch.is_autocast_enabled(device_type):
  180. target_dtype = torch.get_autocast_dtype(device_type)
  181. # Handle the case where the model is quantized
  182. elif hasattr(self.config, "_is_quantized"):
  183. target_dtype = self.config.dtype
  184. else:
  185. target_dtype = self.q_proj.weight.dtype
  186. logger.warning_once(
  187. f"The input hidden states seems to be silently casted in float32, this might be related to"
  188. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  189. f" {target_dtype}."
  190. )
  191. query_states = query_states.to(target_dtype)
  192. key_states = key_states.to(target_dtype)
  193. value_states = value_states.to(target_dtype)
  194. value_states1, value_states2 = torch.chunk(value_states, 2, dim=2)
  195. value_states1 = value_states1.repeat(1, 1, 2, 1)
  196. value_states2 = value_states2.repeat(1, 1, 2, 1)
  197. attn_output1 = _flash_attention_forward(
  198. query_states,
  199. key_states,
  200. value_states1,
  201. attention_mask,
  202. q_len,
  203. position_ids=position_ids,
  204. dropout=dropout_rate,
  205. sliding_window=getattr(self, "sliding_window", None),
  206. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  207. is_causal=self.is_causal,
  208. )
  209. attn_output2 = _flash_attention_forward(
  210. query_states,
  211. key_states,
  212. value_states2,
  213. attention_mask,
  214. q_len,
  215. position_ids=position_ids,
  216. dropout=dropout_rate,
  217. sliding_window=getattr(self, "sliding_window", None),
  218. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  219. is_causal=self.is_causal,
  220. )
  221. attn_output = torch.cat([attn_output1, attn_output2], dim=-1)
  222. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=2)
  223. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  224. query_states.dtype
  225. )
  226. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  227. query_states.dtype
  228. )
  229. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  230. attn_output = attn_output1 - lambda_full * attn_output2
  231. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  232. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  233. attn_output = self.o_proj(attn_output)
  234. return attn_output, None
  235. class DiffLlamaSdpaAttention(DiffLlamaAttention):
  236. """
  237. DiffLlama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  238. `DiffLlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  239. SDPA API.
  240. """
  241. # Adapted from DiffLlamaAttention.forward
  242. def forward(
  243. self,
  244. hidden_states: torch.Tensor,
  245. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  246. attention_mask: torch.Tensor | None = None,
  247. position_ids: torch.LongTensor | None = None,
  248. past_key_values: Cache | None = None,
  249. use_cache: bool = False,
  250. **kwargs,
  251. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  252. bsz, q_len, _ = hidden_states.size()
  253. query_states = self.q_proj(hidden_states)
  254. key_states = self.k_proj(hidden_states)
  255. value_states = self.v_proj(hidden_states)
  256. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  257. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  258. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  259. cos, sin = position_embeddings
  260. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  261. if past_key_values is not None:
  262. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  263. key_states = repeat_kv(key_states, self.num_key_value_groups)
  264. value_states = repeat_kv(value_states, self.num_key_value_groups)
  265. value_states = torch.cat(torch.chunk(value_states, 2, dim=1), dim=-1)
  266. value_states = value_states.repeat(1, 2, 1, 1)
  267. causal_mask = attention_mask
  268. if attention_mask is not None:
  269. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  270. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  271. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  272. is_causal = causal_mask is None and q_len > 1
  273. attn_output = torch.nn.functional.scaled_dot_product_attention(
  274. query_states,
  275. key_states,
  276. value_states,
  277. attn_mask=causal_mask,
  278. dropout_p=self.attention_dropout if self.training else 0.0,
  279. is_causal=is_causal,
  280. )
  281. attn_output1, attn_output2 = torch.chunk(attn_output, 2, dim=1)
  282. lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1, dtype=torch.float32)).to(
  283. query_states.dtype
  284. )
  285. lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1, dtype=torch.float32)).to(
  286. query_states.dtype
  287. )
  288. lambda_full = lambda_1 - lambda_2 + self.lambda_init
  289. attn_output = attn_output1 - lambda_full * attn_output2
  290. attn_output = (1 - self.lambda_init) * self.groupnorm(attn_output)
  291. attn_output = attn_output.transpose(1, 2).contiguous()
  292. attn_output = attn_output.view(bsz, q_len, -1)
  293. attn_output = self.o_proj(attn_output)
  294. return attn_output, None
# Maps an attention-implementation name to the attention class that implements it.
# DiffLlamaDecoderLayer indexes this table with `config._attn_implementation`, so a name that is
# not one of these three keys raises KeyError there.
DIFFLLAMA_ATTENTION_CLASSES = {
    "eager": DiffLlamaAttention,
    "flash_attention_2": DiffLlamaFlashAttention2,
    "sdpa": DiffLlamaSdpaAttention,
}
  300. class DiffLlamaDecoderLayer(LlamaDecoderLayer):
  301. def __init__(self, config: DiffLlamaConfig, layer_idx: int):
  302. super().__init__(config, layer_idx)
  303. self.self_attn = DIFFLLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  304. class DiffLlamaPreTrainedModel(LlamaPreTrainedModel):
  305. _supports_flex_attn = False
  306. _supports_attention_backend = False
  307. @torch.no_grad()
  308. def _init_weights(self, module):
  309. PreTrainedModel._init_weights(self, module)
  310. if isinstance(module, DiffLlamaAttention):
  311. init.normal_(module.lambda_q1, 0, self.config.lambda_std_dev)
  312. init.normal_(module.lambda_k1, 0, self.config.lambda_std_dev)
  313. init.normal_(module.lambda_q2, 0, self.config.lambda_std_dev)
  314. init.normal_(module.lambda_k2, 0, self.config.lambda_std_dev)
class DiffLlamaModel(LlamaModel):
    # No overrides: reuses LlamaModel as-is under a DiffLlama-specific name.
    pass
class DiffLlamaForCausalLM(GemmaForCausalLM):
    # No overrides. Note the base is GemmaForCausalLM, not a Llama class like the sibling heads.
    pass
class DiffLlamaForSequenceClassification(LlamaForSequenceClassification):
    # No overrides: reuses LlamaForSequenceClassification as-is under a DiffLlama-specific name.
    pass
class DiffLlamaForQuestionAnswering(LlamaForQuestionAnswering):
    # No overrides: reuses LlamaForQuestionAnswering as-is under a DiffLlama-specific name.
    pass
class DiffLlamaForTokenClassification(LlamaForTokenClassification):
    # No overrides: reuses LlamaForTokenClassification as-is under a DiffLlama-specific name.
    pass
# Names exported by a star-import of this module. The attention, MLP, rotary-embedding and
# decoder-layer classes defined in this file are not listed.
__all__ = [
    "DiffLlamaPreTrainedModel",
    "DiffLlamaModel",
    "DiffLlamaForCausalLM",
    "DiffLlamaForSequenceClassification",
    "DiffLlamaForQuestionAnswering",
    "DiffLlamaForTokenClassification",
]