| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/t5gemma/modular_t5gemma.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_t5gemma.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
- #
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any
- from huggingface_hub.dataclasses import strict
- from ...configuration_utils import PreTrainedConfig
- from ...modeling_rope_utils import RopeParameters
- from ...utils import auto_docstring
@auto_docstring(checkpoint="google/t5_gemma_module-7b")
@strict
class T5GemmaModuleConfig(PreTrainedConfig):
    r"""
    query_pre_attn_scalar (`int`, *optional*, defaults to 256):
        scaling factor used on the attention scores
    final_logit_softcapping (`float`, *optional*, defaults to 30.0):
        scaling factor when applying tanh softcapping on the logits.
    attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
        scaling factor when applying tanh softcapping on the attention scores.

    ```python
    >>> from transformers import T5GemmaModuleModel, T5GemmaModuleConfig
    >>> # Initializing a T5GemmaModule t5_gemma_module-7b style configuration
    >>> configuration = T5GemmaModuleConfig()
    >>> # Initializing a model from the t5_gemma_module-7b style configuration
    >>> model = T5GemmaModuleModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "t5_gemma_module"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Parameter-name patterns mapped to a "colwise"/"rowwise" strategy
    # (presumably the tensor-parallel sharding plan, going by the attribute name).
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Sub-module name -> (input names, output names)
    # (presumably the pipeline-parallel plan, going by the attribute name).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    # Configuration fields and their defaults. The declaration order is kept
    # exactly as-is because the class is processed by `@strict`.
    vocab_size: int = 256000
    hidden_size: int = 2304
    intermediate_size: int = 9216
    num_hidden_layers: int = 26
    num_attention_heads: int = 8
    num_key_value_heads: int = 4
    head_dim: int = 256
    hidden_activation: str = "gelu_pytorch_tanh"
    max_position_embeddings: int = 8192
    initializer_range: float = 0.02
    rms_norm_eps: float = 1e-6
    use_cache: bool = True
    pad_token_id: int | None = 0
    eos_token_id: int | list[int] | None = 1
    bos_token_id: int | None = 2
    tie_word_embeddings: bool = True
    rope_parameters: RopeParameters | dict | None = None
    attention_bias: bool = False
    attention_dropout: int | float | None = 0.0
    query_pre_attn_scalar: int = 256
    sliding_window: int | None = 4096
    layer_types: list[str] | None = None
    final_logit_softcapping: float | None = 30.0
    attn_logit_softcapping: float | None = 50.0
    is_decoder: bool = False

    def __post_init__(self, **kwargs):
        """Fill in a default `layer_types` when none was given, then run the parent hook.

        The default alternates between the two attention kinds, one entry per
        hidden layer, starting with `"sliding_attention"` at layer index 0.
        """
        if self.layer_types is None:
            # Even layer indices use sliding attention, odd indices full attention.
            alternating_kinds = ("sliding_attention", "full_attention")
            self.layer_types = [alternating_kinds[layer_idx % 2] for layer_idx in range(self.num_hidden_layers)]

        super().__post_init__(**kwargs)

    def validate_architecture(self):
        """Part of `@strict`-powered validation. Validates the architecture of the config.

        Raises:
            ValueError: if `hidden_size` is not a multiple of `num_attention_heads`.
        """
        if self.hidden_size % self.num_attention_heads == 0:
            return

        raise ValueError(
            f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
            f"heads ({self.num_attention_heads})."
        )
@auto_docstring(checkpoint="google/t5_gemma_module-7b")
@strict
class T5GemmaConfig(PreTrainedConfig):
    r"""
    encoder (`Union[T5GemmaModuleConfig, dict]`, *optional*):
        Configuration for the encoder.
    decoder (`Union[T5GemmaModuleConfig, dict]`, *optional*):
        Configuration for the decoder.

    Example:

    ```python
    >>> from transformers import T5GemmaConfig, T5GemmaModel
    >>> t5gemma_config = T5GemmaConfig.from_pretrained("google/t5gemma-2b-2b-prefixlm-it")
    >>> model = T5GemmaModel(t5gemma_config)
    ```"""

    model_type = "t5gemma"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Both halves of the model share the same sub-configuration class.
    sub_configs = {"encoder": T5GemmaModuleConfig, "decoder": T5GemmaModuleConfig}

    # Configuration fields and their defaults. The declaration order is kept
    # exactly as-is because the class is processed by `@strict`.
    encoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
    decoder: T5GemmaModuleConfig | dict[Any, Any] | None = None
    is_encoder_decoder: bool = True
    dropout_rate: int | float = 0.0
    classifier_dropout_rate: int | float = 0.0
    attention_dropout: float | int = 0.0
    tie_word_embeddings: bool = True
    vocab_size: int = 256000

    def __post_init__(self, **kwargs):
        """Normalise the sub-configs, push shared settings into them, then run the parent hook.

        Steps, in order:
          1. `encoder` / `decoder` given as a dict are turned into `T5GemmaModuleConfig`
             instances; `None` is replaced by a default `T5GemmaModuleConfig()`.
          2. `is_decoder`, `dropout_rate` and `attention_dropout` are written onto both
             sub-configs; the decoder additionally gets `use_cache = True` and
             `cross_attention_hidden_size` set to the encoder's `hidden_size`.
          3. `initializer_range` is taken from `kwargs` if present, otherwise from the decoder.
          4. `bos_token_id`, `pad_token_id` and `eos_token_id` missing from `kwargs` are
             filled in from the decoder before `kwargs` is forwarded to the parent hook.
        """
        # Step 1: make sure both sub-configs are T5GemmaModuleConfig instances.
        for side in ("encoder", "decoder"):
            given = getattr(self, side)
            if given is None:
                setattr(self, side, T5GemmaModuleConfig())
            elif isinstance(given, dict):
                setattr(self, side, T5GemmaModuleConfig(**given))

        encoder_config = self.encoder
        decoder_config = self.decoder

        # Step 2: settings copied onto each sub-config (encoder first, decoder second).
        for module_config, acts_as_decoder in ((encoder_config, False), (decoder_config, True)):
            module_config.is_decoder = acts_as_decoder
            module_config.dropout_rate = self.dropout_rate
            module_config.attention_dropout = self.attention_dropout

        # Decoder-only settings.
        decoder_config.use_cache = True
        decoder_config.cross_attention_hidden_size = encoder_config.hidden_size

        # Step 3: an explicit `initializer_range` kwarg wins over the decoder's value.
        decoder_initializer_range = decoder_config.initializer_range
        self.initializer_range = kwargs.pop("initializer_range", decoder_initializer_range)

        # Step 4: special token ids not supplied by the caller default to the decoder's.
        kwargs.update(
            {
                token_key: getattr(decoder_config, token_key)
                for token_key in ("bos_token_id", "pad_token_id", "eos_token_id")
                if token_key not in kwargs
            }
        )

        super().__post_init__(**kwargs)
# Names exported by `from <module> import *`.
__all__ = [
    "T5GemmaConfig",
    "T5GemmaModuleConfig",
]
|