# config_mamba.py (489 B)
  1. from dataclasses import dataclass, field
  2. @dataclass
  3. class MambaConfig:
  4. d_model: int = 2560
  5. d_intermediate: int = 0
  6. n_layer: int = 64
  7. vocab_size: int = 50277
  8. ssm_cfg: dict = field(default_factory=dict)
  9. attn_layer_idx: list = field(default_factory=list)
  10. attn_cfg: dict = field(default_factory=dict)
  11. rms_norm: bool = True
  12. residual_in_fp32: bool = True
  13. fused_add_norm: bool = True
  14. pad_vocab_size_multiple: int = 8
  15. tie_embeddings: bool = True