# configuration_dbrx.py (6.2 KB)
# Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """DBRX model configuration"""
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...modeling_rope_utils import RopeParameters
  18. from ...utils import auto_docstring
  19. @strict
  20. @auto_docstring(
  21. custom_intro="This config is used to instantiate attention layers.",
  22. checkpoint="transformers-community/dbrx-instruct",
  23. )
  24. class DbrxAttentionConfig(PreTrainedConfig):
  25. r"""
  26. attn_pdrop (`float`, *optional*, defaults to 0.0):
  27. The dropout probability for the attention layers.
  28. clip_qkv (`float`, *optional*):
  29. If set, clip the queries, keys, and values in the attention layer to this value.
  30. kv_n_heads (`int`, *optional*, defaults to 1):
  31. For grouped_query_attention only, allow user to specify number of kv heads.
  32. """
  33. base_config_key = "attn_config"
  34. attn_pdrop: float | int = 0.0
  35. clip_qkv: int | float | None = None
  36. kv_n_heads: int = 1
  37. @strict
  38. @auto_docstring(
  39. custom_intro="This config is used to instantiate feedforward layers.",
  40. checkpoint="transformers-community/dbrx-instruct",
  41. )
  42. class DbrxFFNConfig(PreTrainedConfig):
  43. r"""
  44. ffn_act_fn (`dict`, *optional*, defaults to `None`):
  45. A dict specifying activation function for the FFN.
  46. The dict should have a key 'name' with the value being the name of the activation function along with
  47. any additional keyword arguments. If `None`, then set to `{"name": "silu"}`.
  48. ffn_hidden_size (`int`, *optional*, defaults to 3584):
  49. The hidden size of the feedforward network.
  50. moe_num_experts (`int`, *optional*, defaults to 4):
  51. The number of experts in the mixture of experts layer.
  52. moe_top_k (`int`, *optional*, defaults to 1):
  53. The number of experts to use in the mixture of experts layer.
  54. moe_jitter_eps (`float`, *optional*, defaults to `None`):
  55. If not `None`, the jitter epsilon for the mixture of experts layer.
  56. moe_loss_weight (`float`, *optional*, defaults to 0.01):
  57. The loss weight for the mixture of experts layer.
  58. moe_normalize_expert_weights (`float`, *optional*, defaults to 1.0):
  59. The normalization factor for the expert weights.
  60. """
  61. base_config_key = "ffn_config"
  62. hidden_size: int = 6144
  63. ffn_act_fn: dict | None = None
  64. ffn_hidden_size: int = 3584
  65. moe_num_experts: int = 4
  66. moe_top_k: int = 1
  67. moe_jitter_eps: float | None = None
  68. moe_loss_weight: float = 0.01
  69. moe_normalize_expert_weights: float | None = 1.0
  70. def __post_init__(self, **kwargs):
  71. if self.ffn_act_fn is None:
  72. self.ffn_act_fn = {"name": "silu"}
  73. for k in [
  74. "model_type",
  75. "attn_implementation",
  76. "experts_implementation",
  77. "transformers_version",
  78. "_commit_hash",
  79. "torch_dtype",
  80. "dtype",
  81. ]:
  82. if k in kwargs:
  83. kwargs.pop(k)
  84. if len(kwargs) != 0:
  85. raise ValueError(f"Found unknown {kwargs=}")
  86. super().__post_init__(**kwargs)
  87. @auto_docstring(checkpoint="transformers-community/dbrx-instruct")
  88. @strict
  89. class DbrxConfig(PreTrainedConfig):
  90. r"""
  91. max_seq_len (`int`, *optional*, defaults to 2048):
  92. The maximum sequence length of the model.
  93. attn_config (`dict`, *optional*):
  94. A dictionary used to configure the model's attention module.
  95. ffn_config (`dict`, *optional*):
  96. A dictionary used to configure the model's FFN module.
  97. Example:
  98. ```python
  99. >>> from transformers import DbrxConfig, DbrxModel
  100. >>> # Initializing a Dbrx configuration
  101. >>> configuration = DbrxConfig(n_layers=2, d_model=256, n_heads=8, vocab_size=128)
  102. >>> # Initializing a model (with random weights) from the configuration
  103. >>> model = DbrxModel(configuration)
  104. >>> # Accessing the model configuration
  105. >>> configuration = model.config
  106. ```
  107. """
  108. model_type = "dbrx"
  109. sub_configs = {"attn_config": DbrxAttentionConfig, "ffn_config": DbrxFFNConfig}
  110. attribute_map = {
  111. "num_attention_heads": "n_heads",
  112. "hidden_size": "d_model",
  113. "num_hidden_layers": "n_layers",
  114. "max_position_embeddings": "max_seq_len",
  115. }
  116. d_model: int | None = 2048
  117. n_heads: int | None = 16
  118. n_layers: int | None = 24
  119. max_seq_len: int | None = 2048
  120. vocab_size: int = 32000
  121. resid_pdrop: float | None = 0.0
  122. emb_pdrop: float | None = 0.0
  123. attn_config: DbrxAttentionConfig | dict | None = None
  124. ffn_config: DbrxFFNConfig | dict | None = None
  125. use_cache: bool = True
  126. initializer_range: float = 0.02
  127. output_router_logits: bool | None = False
  128. rope_parameters: RopeParameters | dict | None = None
  129. pad_token_id: int | None = None
  130. bos_token_id: int | None = None
  131. eos_token_id: int | list[int] | None = None
  132. tie_word_embeddings: bool = False
  133. def __post_init__(self, **kwargs):
  134. if self.attn_config is None:
  135. self.attn_config = DbrxAttentionConfig()
  136. elif isinstance(self.attn_config, dict):
  137. self.attn_config = DbrxAttentionConfig(**self.attn_config)
  138. if self.ffn_config is None:
  139. self.ffn_config = DbrxFFNConfig()
  140. elif isinstance(self.ffn_config, dict):
  141. self.ffn_config = DbrxFFNConfig(**self.ffn_config)
  142. self.num_key_value_heads = self.attn_config.kv_n_heads
  143. super().__post_init__(**kwargs)
  144. def validate_architecture(self):
  145. """Part of `@strict`-powered validation. Validates the architecture of the config."""
  146. if self.tie_word_embeddings:
  147. raise ValueError("tie_word_embeddings is not supported for DBRX models.")
  148. __all__ = ["DbrxConfig"]