configuration_zamba2.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_rope_utils import RopeParameters
  18. from ...utils import auto_docstring
  19. @auto_docstring(checkpoint="Zyphra/Zamba2-2.7B")
  20. @strict
  21. class Zamba2Config(PreTrainedConfig):
  22. r"""
  23. mamba_ngroups (`int`, *optional*, defaults to 1):
  24. Number of groups for the evolution matrices of mamba 2.
  25. n_mamba_heads (`int`, *optional*, defaults to 8):
  26. Number of heads for the evolution matrices of mamba 2.
  27. use_conv_bias (`bool`, *optional*, defaults to `True`):
  28. Whether or not to use bias in the convolution layer of the mixer block.
  29. chunk_size (`int`, *optional*, defaults to 256):
  30. Size of the chunks that will comprise the sequence.
  31. use_mem_eff_path (`bool`, *optional*, defaults to `False`):
  32. Whether or not to use the fused conv1d and scan in mamba2 layers.
  33. add_bias_linear (`bool`, *optional*, defaults to `False`):
  34. Flag indicating whether or not to use bias in various layers
  35. num_mem_blocks (`int`, *optional*, defaults to 1):
  36. Number of unshared transformer blocks.
  37. use_shared_attention_adapter (`bool`, *optional*, defaults to `False`):
  38. If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers.
  39. adapter_rank (`int`, *optional*, defaults to 128):
  40. Rank of the adapter in the shared MLP and shared attention layers.
  41. use_mem_rope (`bool`, *optional*, defaults to `False`):
  42. If True, includes RoPE in the shared attention layers.
  43. num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
  44. Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
  45. integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
  46. logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
  47. sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
  48. significantly.
  49. use_long_context (`bool`, *optional*, defaults to `False`):
  50. Activates the context-extended version of Zamba by modifying RoPE.
  51. Example:
  52. ```python
  53. >>> from transformers import Zamba2Model, Zamba2Config
  54. >>> # Initializing a Zamba2-2.7B style configuration
  55. >>> configuration = Zamba2Config()
  56. >>> # Initializing a model from the Zamba2-2.7B style configuration
  57. >>> model = Zamba2Model(configuration)
  58. >>> # Accessing the model configuration
  59. >>> configuration = model.config
  60. ```"""
  61. model_type = "zamba2"
  62. attribute_map = {"layer_types": "layers_block_type", "head_dim": "attention_head_dim"}
  63. keys_to_ignore_at_inference = ["past_key_values"]
  64. vocab_size: int = 32000
  65. max_position_embeddings: int = 4096
  66. hidden_size: int = 2560
  67. num_hidden_layers: int = 54
  68. layers_block_type: list[str] | None = None
  69. mamba_d_state: int = 64
  70. mamba_d_conv: int = 4
  71. mamba_expand: int = 2
  72. mamba_ngroups: int = 1
  73. time_step_min: float = 0.001
  74. time_step_max: float = 0.1
  75. time_step_floor: float = 1e-4
  76. time_step_limit: list[float] | tuple[float, ...] | None = None
  77. n_mamba_heads: int = 8
  78. use_conv_bias: bool = True
  79. chunk_size: int = 256
  80. use_mem_eff_path: bool = False
  81. add_bias_linear: bool = False
  82. intermediate_size: int | None = None
  83. hidden_act: str = "gelu"
  84. num_attention_heads: int = 32
  85. num_key_value_heads: int | None = None
  86. attention_dropout: float | int = 0.0
  87. num_mem_blocks: int = 1
  88. use_shared_attention_adapter: bool = False
  89. adapter_rank: int = 128
  90. use_mem_rope: bool = False
  91. rope_parameters: RopeParameters | dict | None = None
  92. initializer_range: float = 0.02
  93. rms_norm_eps: float = 1e-5
  94. use_cache: bool = True
  95. num_logits_to_keep: int = 1
  96. pad_token_id: int | None = 0
  97. bos_token_id: int | None = 1
  98. eos_token_id: int | list[int] | None = 2
  99. use_long_context: bool = False
  100. tie_word_embeddings: bool = True
  101. def __post_init__(self, **kwargs):
  102. self.intermediate_size = self.intermediate_size or 4 * self.hidden_size
  103. self.attention_hidden_size = 2 * self.hidden_size
  104. self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
  105. self.mamba_headdim = int(self.mamba_expand * self.hidden_size) // self.n_mamba_heads
  106. if self.use_long_context:
  107. self.max_position_embeddings = 16384
  108. if self.num_key_value_heads is None:
  109. self.num_key_value_heads = self.num_attention_heads
  110. self.kv_channels = self.hidden_size // self.num_attention_heads
  111. self.num_query_groups = self.num_attention_heads
  112. # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer)
  113. if self.layers_block_type is None:
  114. self.layers_block_type = (
  115. ["mamba"]
  116. + (["mamba"] * 5 + ["hybrid"]) * 7
  117. + ["mamba"] * 4
  118. + ["hybrid"]
  119. + ["mamba"] * 3
  120. + ["hybrid"]
  121. + ["mamba"] * 2
  122. )
  123. self.hybrid_layer_ids = [index for index, type in enumerate(self.layers_block_type) if type == "hybrid"]
  124. super().__post_init__(**kwargs)
# Names exported by `from <module> import *`; the config class is the only one.
__all__ = ["Zamba2Config"]