modular_gemma2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from collections.abc import Callable
  16. import torch
  17. import torch.nn as nn
  18. from huggingface_hub.dataclasses import strict
  19. from ...activations import ACT2FN
  20. from ...cache_utils import Cache, DynamicCache
  21. from ...configuration_utils import PreTrainedConfig
  22. from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  26. from ...modeling_rope_utils import (
  27. ROPE_INIT_FUNCTIONS,
  28. RopeParameters,
  29. dynamic_rope_update,
  30. )
  31. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  32. from ...processing_utils import Unpack
  33. from ...utils import TransformersKwargs, auto_docstring, logging
  34. from ...utils.generic import maybe_autocast
  35. from ..gemma.modeling_gemma import (
  36. GemmaAttention,
  37. GemmaForCausalLM,
  38. GemmaForSequenceClassification,
  39. GemmaForTokenClassification,
  40. GemmaMLP,
  41. GemmaModel,
  42. GemmaPreTrainedModel,
  43. GemmaRMSNorm,
  44. GemmaRotaryEmbedding,
  45. apply_rotary_pos_emb,
  46. repeat_kv,
  47. )
# Module-level logger for this file, named after the module.
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="google/gemma2-7b")
@strict
class Gemma2Config(PreTrainedConfig):
    r"""
    query_pre_attn_scalar (`int`, *optional*, defaults to 256):
        Scalar used to normalize the attention scores: query-key products are multiplied by
        `query_pre_attn_scalar**-0.5`.
    final_logit_softcapping (`float`, *optional*, defaults to 30.0):
        Softcapping value `c` applied to the final logits as `c * tanh(logits / c)`. `None` disables it.
    attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
        Softcapping value `c` for the attention scores, passed to the attention backend as `softcap`;
        the eager implementation maps scores to `c * tanh(scores / c)`. `None` disables it.
    use_bidirectional_attention (`bool`, *optional*):
        If True, the model will attend to all text tokens instead of using a causal mask.
        `None` (the default) behaves like `False`.

    ```python
    >>> from transformers import Gemma2Model, Gemma2Config

    >>> # Initializing a Gemma2 gemma2-7b style configuration
    >>> configuration = Gemma2Config()

    >>> # Initializing a model from the gemma2-7b style configuration
    >>> model = Gemma2Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma2"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel plan: how each projection weight of the base model is split across ranks.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: for each top-level submodule, its (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    vocab_size: int = 256000
    hidden_size: int = 2304
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    # Per-head dimension, stored as its own field.
    head_dim: int = 256
    hidden_activation: str = "gelu_pytorch_tanh"
    max_position_embeddings: int = 8192
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = 0
    eos_token_id: int | list[int] | None = 1
    bos_token_id: int | None = 2
    tie_word_embeddings: bool = True
    # Read by Gemma2RotaryEmbedding, which expects a mapping with a "rope_type" key.
    rope_parameters: RopeParameters | dict | None = None
    attention_bias: bool = False
    attention_dropout: int | float | None = 0.0
    query_pre_attn_scalar: int = 256
    # Window size used only by layers whose type is "sliding_attention".
    sliding_window: int | None = 4096
    # One entry per layer ("sliding_attention" / "full_attention"); filled in by __post_init__ when None.
    layer_types: list[str] | None = None
    final_logit_softcapping: float | None = 30.0
    attn_logit_softcapping: float | None = 50.0
    use_bidirectional_attention: bool | None = None

    def __post_init__(self, **kwargs):
        """Default `layer_types` before running the base-class post-init.

        When no layer types are given, layers alternate starting with sliding attention:
        even layer indices use "sliding_attention", odd ones "full_attention".
        """
        if self.layer_types is None:
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(self.num_hidden_layers)
            ]

        super().__post_init__(**kwargs)

    def validate_architecture(self):
        """Part of `@strict`-powered validation. Validates the architecture of the config."""
        # NOTE(review): `head_dim` is a separate field, so this divisibility check may be stricter
        # than the attention projections require — confirm it is intended.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})."
            )
class Gemma2RMSNorm(GemmaRMSNorm):
    # Behaves exactly like GemmaRMSNorm; subclassed only to give Gemma2 its own named class.
    pass
  126. class Gemma2MLP(GemmaMLP):
  127. def __init__(self, config):
  128. super().__init__(config)
  129. self.act_fn = ACT2FN[config.hidden_activation]
  130. class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
  131. def __init__(self, config: Gemma2Config, device=None):
  132. nn.Module.__init__()
  133. self.max_seq_len_cached = config.max_position_embeddings
  134. self.original_max_seq_len = config.max_position_embeddings
  135. self.config = config
  136. self.rope_type = self.config.rope_parameters["rope_type"]
  137. rope_init_fn: Callable = self.compute_default_rope_parameters
  138. if self.rope_type != "default":
  139. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  140. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  141. self.register_buffer("inv_freq", inv_freq, persistent=False)
  142. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  143. @torch.no_grad()
  144. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  145. def forward(self, x, position_ids):
  146. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  147. position_ids_expanded = position_ids[:, None, :].float()
  148. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  149. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  150. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  151. emb = torch.cat((freqs, freqs), dim=-1)
  152. cos = emb.cos() * self.attention_scaling
  153. sin = emb.sin() * self.attention_scaling
  154. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  155. def eager_attention_forward(
  156. module: nn.Module,
  157. query: torch.Tensor,
  158. key: torch.Tensor,
  159. value: torch.Tensor,
  160. attention_mask: torch.Tensor | None,
  161. dropout: float | int = 0.0,
  162. scaling: float | None = None,
  163. softcap: float | None = None,
  164. **kwargs,
  165. ) -> tuple[torch.Tensor, torch.Tensor]:
  166. if scaling is None:
  167. scaling = module.head_dim**-0.5
  168. key_states = repeat_kv(key, module.num_key_value_groups)
  169. value_states = repeat_kv(value, module.num_key_value_groups)
  170. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  171. if softcap is not None:
  172. attn_weights = attn_weights / softcap
  173. attn_weights = torch.tanh(attn_weights)
  174. attn_weights = attn_weights * softcap
  175. if attention_mask is not None:
  176. attn_weights = attn_weights + attention_mask
  177. # upcast attention to fp32
  178. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  179. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  180. attn_output = torch.matmul(attn_weights, value_states)
  181. attn_output = attn_output.transpose(1, 2).contiguous()
  182. return attn_output, attn_weights
  183. class Gemma2Attention(GemmaAttention):
  184. def __init__(self, config: Gemma2Config, layer_idx: int):
  185. self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
  186. super().__init__(config, layer_idx)
  187. self.attn_logit_softcapping = self.config.attn_logit_softcapping
  188. self.attention_dropout = self.config.attention_dropout
  189. self.is_causal = not getattr(config, "use_bidirectional_attention", False)
  190. self.scaling = config.query_pre_attn_scalar**-0.5
  191. self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  192. def forward(
  193. self,
  194. hidden_states: torch.Tensor,
  195. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  196. attention_mask: torch.Tensor | None = None,
  197. past_key_values: Cache | None = None,
  198. **kwargs: Unpack[FlashAttentionKwargs],
  199. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  200. input_shape = hidden_states.shape[:-1]
  201. hidden_shape = (*input_shape, -1, self.head_dim)
  202. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  203. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  204. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  205. cos, sin = position_embeddings
  206. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  207. if past_key_values is not None:
  208. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  209. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  210. self.config._attn_implementation, eager_attention_forward
  211. )
  212. attn_output, attn_weights = attention_interface(
  213. self,
  214. query_states,
  215. key_states,
  216. value_states,
  217. attention_mask,
  218. dropout=self.attention_dropout if self.training else 0.0,
  219. scaling=self.scaling,
  220. sliding_window=self.sliding_window,
  221. softcap=self.attn_logit_softcapping,
  222. **kwargs,
  223. )
  224. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  225. attn_output = self.o_proj(attn_output)
  226. return attn_output, attn_weights
  227. class Gemma2DecoderLayer(GradientCheckpointingLayer):
  228. def __init__(self, config: Gemma2Config, layer_idx: int):
  229. super().__init__()
  230. self.hidden_size = config.hidden_size
  231. self.config = config
  232. self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
  233. self.mlp = Gemma2MLP(config)
  234. self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  235. self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  236. self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  237. self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  238. def forward(
  239. self,
  240. hidden_states: torch.Tensor,
  241. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  242. attention_mask: torch.Tensor | None = None,
  243. position_ids: torch.LongTensor | None = None,
  244. past_key_values: Cache | None = None,
  245. **kwargs,
  246. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  247. residual = hidden_states
  248. hidden_states = self.input_layernorm(hidden_states)
  249. # Self Attention
  250. hidden_states, _ = self.self_attn(
  251. hidden_states=hidden_states,
  252. position_embeddings=position_embeddings,
  253. attention_mask=attention_mask,
  254. position_ids=position_ids,
  255. past_key_values=past_key_values,
  256. **kwargs,
  257. )
  258. hidden_states = self.post_attention_layernorm(hidden_states)
  259. hidden_states = residual + hidden_states
  260. residual = hidden_states
  261. hidden_states = self.pre_feedforward_layernorm(hidden_states)
  262. hidden_states = self.mlp(hidden_states)
  263. hidden_states = self.post_feedforward_layernorm(hidden_states)
  264. hidden_states = residual + hidden_states
  265. return hidden_states
class Gemma2PreTrainedModel(GemmaPreTrainedModel):
    # Behaves exactly like GemmaPreTrainedModel; subclassed only to give Gemma2 its own named base class.
    pass
  268. class Gemma2Model(GemmaModel):
  269. def __init__(self, config: Gemma2Config):
  270. super().__init__(config)
  271. self.layers = nn.ModuleList(
  272. [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  273. )
  274. self.rotary_emb = Gemma2RotaryEmbedding(config)
  275. def forward(
  276. self,
  277. input_ids: torch.LongTensor | None = None,
  278. attention_mask: torch.Tensor | None = None,
  279. position_ids: torch.LongTensor | None = None,
  280. past_key_values: Cache | None = None,
  281. inputs_embeds: torch.FloatTensor | None = None,
  282. use_cache: bool | None = None,
  283. **kwargs: Unpack[TransformersKwargs],
  284. ) -> BaseModelOutputWithPast:
  285. if (input_ids is None) ^ (inputs_embeds is not None):
  286. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  287. if inputs_embeds is None:
  288. inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
  289. if use_cache and past_key_values is None:
  290. past_key_values = DynamicCache(config=self.config)
  291. if position_ids is None:
  292. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  293. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  294. position_ids = position_ids.unsqueeze(0)
  295. # It may already have been prepared by e.g. `generate`
  296. if not isinstance(causal_mask_mapping := attention_mask, dict):
  297. # Prepare mask arguments
  298. mask_kwargs = {
  299. "config": self.config,
  300. "inputs_embeds": inputs_embeds,
  301. "attention_mask": attention_mask,
  302. "past_key_values": past_key_values,
  303. "position_ids": position_ids,
  304. }
  305. # Create the masks
  306. causal_mask_mapping = {
  307. "full_attention": create_causal_mask(**mask_kwargs),
  308. "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
  309. }
  310. # embed positions
  311. hidden_states = inputs_embeds
  312. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  313. for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
  314. hidden_states = decoder_layer(
  315. hidden_states,
  316. attention_mask=causal_mask_mapping[self.config.layer_types[i]],
  317. position_embeddings=position_embeddings,
  318. position_ids=position_ids,
  319. past_key_values=past_key_values,
  320. **kwargs,
  321. )
  322. hidden_states = self.norm(hidden_states)
  323. return BaseModelOutputWithPast(
  324. last_hidden_state=hidden_states,
  325. past_key_values=past_key_values,
  326. )
  327. class Gemma2ForCausalLM(GemmaForCausalLM):
  328. def __init__(self, config):
  329. super().__init__(config)
  330. self.model = Gemma2Model(config)
  331. self.post_init()
  332. def forward(
  333. self,
  334. input_ids: torch.LongTensor | None = None,
  335. attention_mask: torch.Tensor | None = None,
  336. position_ids: torch.LongTensor | None = None,
  337. past_key_values: Cache | None = None,
  338. inputs_embeds: torch.FloatTensor | None = None,
  339. labels: torch.LongTensor | None = None,
  340. use_cache: bool | None = None,
  341. logits_to_keep: int | torch.Tensor = 0,
  342. **kwargs: Unpack[TransformersKwargs],
  343. ) -> CausalLMOutputWithPast:
  344. r"""
  345. Example:
  346. ```python
  347. >>> from transformers import AutoTokenizer, Gemma2ForCausalLM
  348. >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
  349. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
  350. >>> prompt = "What is your favorite condiment?"
  351. >>> inputs = tokenizer(prompt, return_tensors="pt")
  352. >>> # Generate
  353. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  354. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  355. "What is your favorite condiment?"
  356. ```"""
  357. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  358. outputs: BaseModelOutputWithPast = self.model(
  359. input_ids=input_ids,
  360. attention_mask=attention_mask,
  361. position_ids=position_ids,
  362. past_key_values=past_key_values,
  363. inputs_embeds=inputs_embeds,
  364. use_cache=use_cache,
  365. **kwargs,
  366. )
  367. hidden_states = outputs.last_hidden_state
  368. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  369. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  370. logits = self.lm_head(hidden_states[:, slice_indices, :])
  371. if self.config.final_logit_softcapping is not None:
  372. logits = logits / self.config.final_logit_softcapping
  373. logits = torch.tanh(logits)
  374. logits = logits * self.config.final_logit_softcapping
  375. loss = None
  376. if labels is not None:
  377. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  378. return CausalLMOutputWithPast(
  379. loss=loss,
  380. logits=logits,
  381. past_key_values=outputs.past_key_values,
  382. hidden_states=outputs.hidden_states,
  383. attentions=outputs.attentions,
  384. )
class Gemma2ForSequenceClassification(GemmaForSequenceClassification):
    # Behaves exactly like GemmaForSequenceClassification; subclassed only to give Gemma2 its own named class.
    pass
class Gemma2ForTokenClassification(GemmaForTokenClassification):
    # Behaves exactly like GemmaForTokenClassification; subclassed only to give Gemma2 its own named class.
    pass
# Public names exported by this module (also what `from ... import *` brings in).
__all__ = [
    "Gemma2Config",
    "Gemma2ForCausalLM",
    "Gemma2Model",
    "Gemma2PreTrainedModel",
    "Gemma2ForSequenceClassification",
    "Gemma2ForTokenClassification",
]