configuration_zamba.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Zamba model configuration"""
  15. import math
  16. from huggingface_hub.dataclasses import strict
  17. from ...configuration_utils import PreTrainedConfig
  18. from ...utils import auto_docstring
  19. @auto_docstring(checkpoint="Zyphra/Zamba-7B-v1")
  20. @strict
  21. class ZambaConfig(PreTrainedConfig):
  22. r"""
  23. attention_hidden_size (`int`, *optional*):
  24. Dimension of the hidden representations of the inputs to the Attention layer.
  25. attention_head_dim (`int`, *optional*):
  26. Dimension of the attention head in the Transformer decoder.
  27. n_mamba_heads (`int`, *optional*, defaults to 2):
  28. Number of mamba heads for each mamba layer.
  29. hidden_mamba_act (`str` or `function`, *optional*, defaults to `"silu"`):
  30. The non-linear activation function (function or string) in the mamba layer.
  31. num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
  32. Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
  33. integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
  34. logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
  35. sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
  36. significantly.
  37. attn_layer_period (`int`, *optional*, defaults to 6):
  38. Once in this many layers, we will have a shared attention layer
  39. attn_layer_offset (`int`, *optional*, defaults to 4):
  40. Offset of the shared attention layer
  41. use_mamba_kernels (`bool`, *optional*, defaults to `True`):
  42. Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
  43. `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
  44. `True` and kernels are not available
  45. mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
  46. Rank of the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
  47. """
  48. model_type = "zamba"
  49. keys_to_ignore_at_inference = ["past_key_values"]
  50. attribute_map = {"layer_types": "layers_block_type", "head_dim": "attention_head_dim"}
  51. vocab_size: int = 32000
  52. tie_word_embeddings: bool = True
  53. hidden_size: int = 3712
  54. attention_hidden_size: int | None = None
  55. intermediate_size: int = 14848
  56. num_hidden_layers: int = 76
  57. num_attention_heads: int = 16
  58. attention_head_dim: int | None = None
  59. num_key_value_heads: int = 16
  60. n_mamba_heads: int = 2
  61. hidden_act: str = "gelu"
  62. hidden_mamba_act: str = "silu"
  63. initializer_range: float = 0.02
  64. rms_norm_eps: float = 1e-5
  65. use_cache: bool = True
  66. num_logits_to_keep: int = 1
  67. pad_token_id: int | None = 0
  68. bos_token_id: int | None = 1
  69. eos_token_id: int | list[int] | None = 2
  70. max_position_embeddings: int = 4096
  71. attention_dropout: float | int = 0.0
  72. attn_layer_period: int = 6
  73. attn_layer_offset: int = 4
  74. use_mamba_kernels: bool = True
  75. mamba_d_state: int = 16
  76. mamba_d_conv: int = 4
  77. mamba_expand: int = 2
  78. mamba_dt_rank: str | int = "auto"
  79. time_step_min: float = 0.001
  80. time_step_max: float = 0.1
  81. time_step_floor: float = 1e-4
  82. mamba_conv_bias: bool = True
  83. mamba_proj_bias: bool = False
  84. def __post_init__(self, **kwargs):
  85. self.attention_hidden_size = self.attention_hidden_size or 2 * self.hidden_size
  86. self.attention_head_dim = self.attention_head_dim or 2 * self.hidden_size // self.num_attention_heads
  87. self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if self.mamba_dt_rank == "auto" else self.mamba_dt_rank
  88. self.layers_block_type = self._layers_block_type(
  89. self.num_hidden_layers, self.attn_layer_period, self.attn_layer_offset
  90. )
  91. super().__post_init__(**kwargs)
  92. def validate_architecture(self):
  93. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  94. if (self.mamba_expand * self.hidden_size) % self.n_mamba_heads != 0:
  95. raise ValueError("`intermediate_size` should be divisible by `n_mamba_heads`.")
  96. def _layers_block_type(self, num_hidden_layers, attn_layer_period, attn_layer_offset):
  97. layers = [
  98. "mamba",
  99. "mamba",
  100. "hybrid",
  101. ] + ["hybrid" if i % attn_layer_period == attn_layer_offset else "mamba" for i in range(num_hidden_layers - 3)]
  102. return layers
  103. __all__ = ["ZambaConfig"]