configuration_xlstm.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # Copyright 2025 NXAI GmbH. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """xLSTM configuration."""
  15. from huggingface_hub.dataclasses import strict
  16. from ...configuration_utils import PreTrainedConfig
  17. from ...utils import auto_docstring, is_xlstm_available
  18. if is_xlstm_available():
  19. from xlstm.xlstm_large.model import (
  20. BackendModeType,
  21. ChunkwiseKernelType,
  22. DtypeType,
  23. SequenceKernelType,
  24. StepKernelType,
  25. WeightModeType,
  26. round_up_to_next_multiple_of,
  27. xLSTMLargeConfig,
  28. )
  29. external_xlstm = True
  30. else:
  31. from typing import Literal
  32. BackendModeType = Literal["train", "train_with_padding", "inference"]
  33. ChunkwiseKernelType = Literal[
  34. "chunkwise--native_autograd",
  35. "parallel--native_autograd",
  36. ]
  37. DtypeType = Literal["float32", "bfloat16", "float16"]
  38. SequenceKernelType = Literal["native_sequence__native"]
  39. StepKernelType = Literal["native"]
  40. WeightModeType = Literal["single", "fused"]
  41. def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
  42. """Rounds up x to the next multiple of multiple_of."""
  43. return int(((x + multiple_of - 1) // multiple_of) * multiple_of)
  44. external_xlstm = False
  45. @auto_docstring(checkpoint="NX-AI/xLSTM-7b")
  46. @strict
  47. class xLSTMConfig(PreTrainedConfig):
  48. r"""
  49. num_blocks (int, optional, *optional*, defaults to 32):
  50. Number of blocks of the xLSTM model, use num_hidden_layers if None.
  51. num_heads (int, optional, *optional*, defaults to 8):
  52. Number of heads for the xLSTM Layer/Cell.
  53. use_bias (bool, optional, *optional*, defaults to `False`):
  54. Whether to use biases in the xLSTM model.
  55. norm_reduction_force_float32 (bool, optional, *optional*, defaults to `True`):
  56. Whether to force the float32 norm reduction op to be done in fp32 precision.
  57. add_out_norm (bool, optional, *optional*, defaults to `True`):
  58. Whether to add an output norm after the blocks before the LMHead.
  59. qk_dim_factor (float, optional, *optional*, defaults to 0.5):
  60. Scale factor for the query and key dimension.
  61. v_dim_factor (float, optional, *optional*, defaults to 1.0):
  62. Scale factor for the value dimension.
  63. chunkwise_kernel (ChunkwiseKernelType, optional, *optional*, defaults to `"chunkwise--native_autograd"`):
  64. Kernel type for chunkwise processing mode.
  65. sequence_kernel (SequenceKernelType, optional, *optional*, defaults to `"native_sequence__native"`):
  66. Kernel type for sequence processing mode.
  67. step_kernel (StepKernelType, optional, *optional*, defaults to `"native"`):
  68. Kernel type for step processing mode.
  69. mode (BackendModeType, optional, *optional*, defaults to `"inference"`):
  70. Operation mode (inference is needed for generation).
  71. chunk_size (int, optional, *optional*, defaults to 64):
  72. Internal chunk size.
  73. return_last_states (bool, optional, *optional*, defaults to `True`):
  74. If to return the last states / cache internally. Needed as True for generation.
  75. autocast_kernel_dtype (DtypeType, optional, *optional*, defaults to `"bfloat16"`):
  76. Kernel dtype for the states.
  77. inference_state_dtype (DtypeType, optional, *optional*, defaults to `"float32"`):
  78. Kernel dtype for states in inference.
  79. ffn_proj_factor (float, optional, *optional*, defaults to 2.667):
  80. Size factor of the post-up projection gated Feed Forward network.
  81. ffn_round_up_to_multiple_of (int, optional, *optional*, defaults to 64):
  82. Size factor round value of the post-up projection gated Feed Forward network.
  83. gate_soft_cap (float, optional, *optional*, defaults to 15.0):
  84. Gate soft cap scale.
  85. output_logit_soft_cap (float, optional, *optional*, defaults to 30.0):
  86. Output logit soft cap scale.
  87. weight_mode (`Literal`, *optional*, defaults to `"single"`):
  88. Whether parallel linear layers are separated or fused (single).
  89. max_inference_chunksize (int, optional, *optional*, defaults to 16384):
  90. Limit the chunk size for inference to save memory.
  91. Example:
  92. ```python
  93. >>> from transformers import xLSTMConfig, xLSTMModel
  94. >>> # Initializing a xLSTM configuration
  95. >>> configuration = xLSTMConfig()
  96. >>> # Initializing a model (with random weights) from the configuration
  97. >>> model = xLSTMModel(configuration)
  98. >>> # Accessing the model configuration
  99. >>> configuration = model.config
  100. ```"""
  101. model_type = "xlstm"
  102. vocab_size: int = 50304
  103. hidden_size: int = 4096
  104. embedding_dim: int | None = None
  105. num_hidden_layers: int = 32
  106. num_blocks: int | None = None
  107. num_heads: int = 8
  108. use_bias: bool = False
  109. norm_reduction_force_float32: bool = True
  110. tie_word_embeddings: bool = False
  111. add_out_norm: bool = True
  112. norm_eps: float = 1e-6
  113. qk_dim_factor: float = 0.5
  114. v_dim_factor: float = 1.0
  115. chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd"
  116. sequence_kernel: SequenceKernelType = "native_sequence__native"
  117. step_kernel: StepKernelType = "native"
  118. mode: BackendModeType = "inference"
  119. chunk_size: int = 64
  120. return_last_states: bool = True
  121. autocast_kernel_dtype: DtypeType = "bfloat16"
  122. eps: float = 1e-6
  123. inference_state_dtype: DtypeType = "float32"
  124. ffn_proj_factor: float = 2.667
  125. ffn_round_up_to_multiple_of: int = 64
  126. gate_soft_cap: float = 15.0
  127. output_logit_soft_cap: float = 30.0
  128. weight_mode: WeightModeType = "single"
  129. use_cache: bool = True
  130. pad_token_id: int | None = 1
  131. bos_token_id: int | None = 0
  132. eos_token_id: int | list[int] | None = 2
  133. max_inference_chunksize: int = 16384
  134. def __post_init__(self, **kwargs):
  135. self.hidden_size = self.hidden_size if self.hidden_size is not None else self.embedding_dim
  136. self.embedding_dim = self.embedding_dim if self.embedding_dim is not None else self.hidden_size
  137. self.num_hidden_layers = self.num_hidden_layers if self.num_hidden_layers is not None else self.num_blocks
  138. self.num_blocks = self.num_blocks if self.num_blocks is not None else self.num_hidden_layers
  139. super().__post_init__(**kwargs)
  140. @property
  141. def qk_dim(self):
  142. return round_up_to_next_multiple_of(
  143. self.hidden_size * self.qk_dim_factor,
  144. multiple_of=64,
  145. )
  146. @property
  147. def v_dim(self):
  148. return round_up_to_next_multiple_of(
  149. self.hidden_size * self.v_dim_factor,
  150. multiple_of=64,
  151. )
  152. @property
  153. def qk_head_dim(self):
  154. return self.qk_dim // self.num_heads
  155. @property
  156. def v_head_dim(self):
  157. return self.v_dim // self.num_heads
  158. def to_xlstm_block_config(self):
  159. if external_xlstm:
  160. return xLSTMLargeConfig(
  161. vocab_size=self.vocab_size,
  162. embedding_dim=self.hidden_size,
  163. num_blocks=self.num_hidden_layers,
  164. num_heads=self.num_heads,
  165. use_bias=self.use_bias,
  166. add_out_norm=self.add_out_norm,
  167. norm_eps=self.norm_eps,
  168. norm_reduction_force_float32=self.norm_reduction_force_float32,
  169. # mlstm_layer
  170. qk_dim_factor=self.qk_dim_factor,
  171. v_dim_factor=self.v_dim_factor,
  172. # mlstm backend
  173. chunkwise_kernel=self.chunkwise_kernel,
  174. sequence_kernel=self.sequence_kernel,
  175. step_kernel=self.step_kernel,
  176. mode=self.mode,
  177. chunk_size=self.chunk_size,
  178. return_last_states=self.return_last_states,
  179. autocast_kernel_dtype=self.autocast_kernel_dtype,
  180. eps=self.eps,
  181. inference_state_dtype=self.inference_state_dtype,
  182. # feedforward
  183. ffn_proj_factor=self.ffn_proj_factor,
  184. ffn_round_up_to_multiple_of=self.ffn_round_up_to_multiple_of,
  185. # capping
  186. gate_soft_cap=self.gate_soft_cap,
  187. output_logit_soft_cap=self.output_logit_soft_cap,
  188. weight_mode=self.weight_mode,
  189. )
  190. else:
  191. return self
# Public names exported by a star-import of this module.
__all__ = ["xLSTMConfig"]