modular_gemma.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from huggingface_hub.dataclasses import strict
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...cache_utils import Cache, DynamicCache
  20. from ...configuration_utils import PreTrainedConfig
  21. from ...masking_utils import create_causal_mask
  22. from ...modeling_outputs import BaseModelOutputWithPast
  23. from ...modeling_rope_utils import RopeParameters
  24. from ...modeling_utils import PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, logging
  27. from ..llama.modeling_llama import (
  28. LlamaAttention,
  29. LlamaForCausalLM,
  30. LlamaForSequenceClassification,
  31. LlamaForTokenClassification,
  32. LlamaMLP,
  33. LlamaModel,
  34. LlamaPreTrainedModel,
  35. LlamaRotaryEmbedding,
  36. )
  37. VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
  38. SPIECE_UNDERLINE = "▁"
  39. logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="google/gemma-7b")
@strict
class GemmaConfig(PreTrainedConfig):
    r"""
    use_bidirectional_attention (`bool`, *optional*):
        If True, the model will attend to all text tokens instead of using a causal mask.

    ```python
    >>> from transformers import GemmaModel, GemmaConfig

    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()

    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma"
    # NOTE(review): presumably output keys that generic inference/eval helpers should skip.
    keys_to_ignore_at_inference = ["past_key_values"]
    # Tensor-parallel plan: submodule name pattern -> split style ("colwise" / "rowwise").
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline-parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    # Defaults describe a gemma-7b-style model (see the docstring example).
    # Field order is significant: it is the dataclass field order.
    vocab_size: int = 256000
    hidden_size: int = 3072
    intermediate_size: int = 24576
    num_hidden_layers: int = 28
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    # Set explicitly: num_attention_heads * head_dim (16 * 256 = 4096) != hidden_size (3072).
    head_dim: int = 256
    hidden_act: str = "gelu_pytorch_tanh"
    max_position_embeddings: int = 8192
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = 0
    eos_token_id: int | list[int] | None = 1
    bos_token_id: int | None = 2
    tie_word_embeddings: bool = True
    rope_parameters: RopeParameters | dict | None = None
    attention_bias: bool = False
    attention_dropout: float | int = 0.0
    # None or False -> causal attention; True -> GemmaAttention sets `is_causal = False`.
    use_bidirectional_attention: bool | None = None
  91. class GemmaTextScaledWordEmbedding(nn.Embedding):
  92. """
  93. This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
  94. """
  95. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
  96. super().__init__(num_embeddings, embedding_dim, padding_idx)
  97. self.scalar_embed_scale = embed_scale
  98. self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
  99. def forward(self, input_ids: torch.Tensor):
  100. return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
  101. class GemmaRMSNorm(nn.Module):
  102. def __init__(self, dim: int, eps: float = 1e-6):
  103. super().__init__()
  104. self.eps = eps
  105. self.weight = nn.Parameter(torch.zeros(dim))
  106. def _norm(self, x):
  107. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  108. def forward(self, x):
  109. output = self._norm(x.float())
  110. # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
  111. # See https://github.com/huggingface/transformers/pull/29402
  112. output = output * (1.0 + self.weight.float())
  113. return output.type_as(x)
  114. def extra_repr(self):
  115. return f"{tuple(self.weight.shape)}, eps={self.eps}"
class GemmaMLP(LlamaMLP):
    # MLP reusing LlamaMLP, with gate/up/down projections re-created without bias.
    def __init__(self, config):
        # NOTE(review): the parent constructor runs first and the three projections are then
        # replaced. This is the modular-file override pattern; presumably the converter merges
        # the parent body here. `self.hidden_size` / `self.intermediate_size` are assumed to be
        # set by LlamaMLP.__init__ — confirm against the parent.
        super().__init__(config)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
# Same behavior as LlamaRotaryEmbedding; declared so Gemma exposes its own named class.
class GemmaRotaryEmbedding(LlamaRotaryEmbedding):
    pass
  124. class GemmaAttention(LlamaAttention):
  125. """Multi-headed attention from 'Attention Is All You Need' paper"""
  126. def __init__(self, config: GemmaConfig, layer_idx: int):
  127. super().__init__()
  128. self.is_causal = not getattr(config, "use_bidirectional_attention", False)
  129. class GemmaPreTrainedModel(LlamaPreTrainedModel):
  130. @torch.no_grad()
  131. def _init_weights(self, module):
  132. PreTrainedModel._init_weights(self, module)
  133. # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
  134. if "RMSNorm" in module.__class__.__name__:
  135. init.zeros_(module.weight)
  136. elif isinstance(module, GemmaTextScaledWordEmbedding):
  137. init.constant_(module.embed_scale, module.scalar_embed_scale)
  138. class GemmaModel(LlamaModel):
  139. def __init__(self, config: GemmaConfig):
  140. super().__init__(config)
  141. # Gemma3 downcasts the below to bfloat16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
  142. self.embed_tokens = GemmaTextScaledWordEmbedding(
  143. config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5
  144. )
  145. def forward(
  146. self,
  147. input_ids: torch.LongTensor | None = None,
  148. attention_mask: torch.Tensor | None = None,
  149. position_ids: torch.LongTensor | None = None,
  150. past_key_values: Cache | None = None,
  151. inputs_embeds: torch.FloatTensor | None = None,
  152. use_cache: bool | None = None,
  153. **kwargs: Unpack[TransformersKwargs],
  154. ) -> BaseModelOutputWithPast:
  155. if (input_ids is None) ^ (inputs_embeds is not None):
  156. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  157. if inputs_embeds is None:
  158. inputs_embeds = self.embed_tokens(input_ids)
  159. if use_cache and past_key_values is None:
  160. past_key_values = DynamicCache(config=self.config)
  161. if position_ids is None:
  162. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  163. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  164. position_ids = position_ids.unsqueeze(0)
  165. causal_mask = create_causal_mask(
  166. config=self.config,
  167. inputs_embeds=inputs_embeds,
  168. attention_mask=attention_mask,
  169. past_key_values=past_key_values,
  170. position_ids=position_ids,
  171. )
  172. # embed positions
  173. hidden_states = inputs_embeds
  174. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  175. for decoder_layer in self.layers[: self.config.num_hidden_layers]:
  176. hidden_states = decoder_layer(
  177. hidden_states,
  178. attention_mask=causal_mask,
  179. position_ids=position_ids,
  180. past_key_values=past_key_values,
  181. use_cache=use_cache,
  182. position_embeddings=position_embeddings,
  183. **kwargs,
  184. )
  185. hidden_states = self.norm(hidden_states)
  186. return BaseModelOutputWithPast(
  187. last_hidden_state=hidden_states,
  188. past_key_values=past_key_values if use_cache else None,
  189. )
class GemmaForCausalLM(LlamaForCausalLM):
    # NOTE(review): `forward` below takes no `self`, only `**super_kwargs`, and uses zero-arg
    # `super()`, so it is not callable as an ordinary bound method. This looks like the
    # modular-file convention where the converter expands `**super_kwargs` into the parent's
    # signature and keeps this docstring. Confirm before "fixing" the signature.
    def forward(**super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        return super().forward(**super_kwargs)
# Sequence-classification head: inherits everything from the Llama implementation unchanged.
class GemmaForSequenceClassification(LlamaForSequenceClassification):
    pass
# Token-classification head: inherits everything from the Llama implementation unchanged.
class GemmaForTokenClassification(LlamaForTokenClassification):
    pass
# Public API of this module. Kept as a plain list literal so tooling can read it.
__all__ = [
    "GemmaConfig",
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",
]