Metadata-Version: 2.4
Name: mamba_ssm
Version: 2.2.4
Summary: Efficient implementation of selective state space models (Mamba)
Home-page: https://github.com/state-spaces/mamba
Author: Tri Dao, Albert Gu
Author-email: Albert Gu <albertgu@stanford.edu>, Tri Dao <trid@cs.stanford.edu>
License: MIT
Project-URL: Homepage, https://github.com/state-spaces/mamba
Project-URL: Bug Tracker, https://github.com/state-spaces/mamba/issues
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
License-File: AUTHORS
Requires-Dist: torch>=2.4.0
Requires-Dist: einops
Requires-Dist: transformers>=4.51.3
Requires-Dist: triton-windows; platform_system == "Windows"
Requires-Dist: triton; platform_system != "Windows"
Dynamic: author
Dynamic: home-page
Dynamic: license-file
Dynamic: requires-python

# Mamba

![Mamba](assets/selection.png "Selective State Space")
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
> Albert Gu*, Tri Dao*\
> Paper: https://arxiv.org/abs/2312.00752

![Mamba-2](assets/ssd_algorithm.png "State Space Dual Model")
> **Transformers are SSMs: Generalized Models and Efficient Algorithms**\
> **Through Structured State Space Duality**\
> Tri Dao*, Albert Gu*\
> Paper: https://arxiv.org/abs/2405.21060

## About

Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
## Installation

- [Optional] `pip install "causal-conv1d>=1.4.0"`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. Quote the requirement so the shell does not treat `>` as a redirect.
- `pip install mamba-ssm`: the core Mamba package.
- `pip install mamba-ssm[causal-conv1d]`: the core Mamba package plus causal-conv1d.
- `pip install mamba-ssm[dev]`: the core Mamba package plus dev dependencies.

It can also be built from source with `pip install .` from this repository.

If installation fails, either when building from source or when installing from PyPI, try passing `--no-build-isolation` to `pip`. This commonly resolves `pip` complaints about PyTorch versions, but other cases exist as well.
Other requirements:
- Linux
- NVIDIA GPU
- PyTorch 2.4+ (the package metadata requires `torch>=2.4.0`)
- CUDA 11.6+

For AMD cards, see additional prerequisites below.
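Before installing, it can help to check what the current environment already provides. The following is a minimal sketch that uses only the Python standard library; `check_environment` is a hypothetical helper, not part of `mamba_ssm`, and it reports only the Python version, the operating system, and which of the declared dependencies are installed. It does not check the GPU or the CUDA toolkit.

```python
import platform
import sys
from importlib import metadata


def check_environment(min_python=(3, 9)):
    """Collect basic facts relevant to the requirements listed above."""
    report = {
        "python_ok": sys.version_info >= min_python,
        "is_linux": platform.system() == "Linux",
    }
    # Record the installed version of each declared dependency (None if absent).
    for dist in ("torch", "einops", "transformers", "triton"):
        try:
            report[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            report[dist] = None
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

A `None` entry means the distribution is not installed in the active environment; a version string still has to be compared against the `Requires-Dist` constraints by hand.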
## Usage

We expose several levels of interface with the Mamba model.

### Selective SSM

Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).

Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
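
To make the recurrence behind the selective SSM concrete, here is a readable, sequential sketch in plain Python for a single channel. It is not the fused kernel in `ops/selective_scan_interface.py`; the function name `selective_scan_reference` and its argument layout are assumptions chosen for illustration. The step size `delta` and the projections `B` and `C` vary per time step, which is what makes the scan "selective".

```python
import math


def selective_scan_reference(x, delta, A, B, C, D=0.0):
    """Sequential selective-SSM recurrence for a single channel.

    x, delta: length-L lists of floats (input and per-step step size, delta > 0)
    A:        length-N list of floats (diagonal state matrix, typically negative)
    B, C:     length-L lists of length-N lists (input-dependent projections)
    D:        scalar skip connection
    """
    n = len(A)
    h = [0.0] * n  # hidden state, one entry per state dimension
    ys = []
    for t in range(len(x)):
        for i in range(n):
            a_bar = math.exp(delta[t] * A[i])  # discretized A (zero-order hold)
            b_bar = delta[t] * B[t][i]         # discretized B (simplified Euler step)
            h[i] = a_bar * h[i] + b_bar * x[t]
        # Read out the state with the step-dependent C, plus the skip term.
        ys.append(sum(C[t][i] * h[i] for i in range(n)) + D * x[t])
    return ys
```

With `A = [0.0]` the state never decays and the output is a running sum of the input; with a strongly negative `A` the state forgets almost immediately and each output depends only on the current input.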
### Mamba Block

The main module of this repository is the Mamba architecture block wrapping the selective SSM.

Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).

Usage:
``` python
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
### Mamba-2

The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).

A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py).

The usage is similar to Mamba(-1):
``` python
import torch
from mamba_ssm import Mamba2

# Assumption: expand * d_model should be a multiple of the block's head
# dimension, so this example uses a larger dim than the Mamba(-1) example.
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
#### SSD

A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions
is at [modules/ssd_minimal.py](mamba_ssm/modules/ssd_minimal.py).

### Mamba Language Model

Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.

Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).

This is an example of how to integrate Mamba into an end-to-end neural network.
This example is used in the generation scripts below.
## Pretrained Models

Pretrained models are uploaded to
[Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,
`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
(trained on 600B tokens on the SlimPajama dataset).

The models will be autodownloaded by the generation script below.

These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile), and follow the standard model dimensions described by GPT-3 and followed by many open source models:

| Parameters | Layers | Model dim. |
|------------|--------|------------|
| 130M       | 24     | 768        |
| 370M       | 48     | 1024       |
| 790M       | 48     | 1536       |
| 1.4B       | 48     | 2048       |
| 2.8B       | 64     | 2560       |

(The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
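
As a rough sanity check on the table, the "roughly 3 * expand * d_model^2 parameters" figure quoted for a Mamba block can be multiplied by the layer count and added to an embedding term. This is a back-of-the-envelope sketch, not the exact count: `estimate_mamba_params` is a hypothetical helper, the vocabulary size of 50280 and the tied embedding are assumptions, and the smaller per-block terms (convolution, SSM parameters, norms) are ignored, so the result underestimates the true size.

```python
def estimate_mamba_params(n_layers, d_model, expand=2, vocab_size=50280):
    """Rough parameter count: per-block projections plus one embedding matrix."""
    # The input projection maps d_model -> 2 * expand * d_model and the output
    # projection maps expand * d_model -> d_model, i.e. 3 * expand * d_model^2.
    per_block = 3 * expand * d_model ** 2
    # Assumption: the embedding is shared with the language model head.
    embedding = vocab_size * d_model
    return n_layers * per_block + embedding


# (layers, model dim) pairs taken from the table above.
configs = {
    "130M": (24, 768),
    "370M": (48, 1024),
    "790M": (48, 1536),
    "1.4B": (48, 2048),
    "2.8B": (64, 2560),
}
for name, (layers, dim) in configs.items():
    print(f"{name}: ~{estimate_mamba_params(layers, dim) / 1e6:.0f}M parameters")
```

The estimates land somewhat below the nominal sizes, which is consistent with the omitted per-block terms.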
Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
Performance is expected to be comparable to or better than other architectures trained on similar data, but not to match larger or fine-tuned models.
## Evaluations

To run zero-shot evaluations of models (corresponding to Table 3 of the paper),
we use the
[lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)
library.

1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`.
2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
```

To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blog posts:
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
```

To run evaluations on Mamba-2 models, simply replace the model names:
``` sh
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
```

Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
## Inference

The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
1. autoloads a model from the Hugging Face Hub,
2. generates completions of a user-specified prompt,
3. benchmarks the inference speed of this generation.

Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
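
For readers unfamiliar with these two options, the sketch below shows what temperature scaling and top-p filtering do to a vector of logits. It is a generic, standard-library illustration and not the sampling code used by the benchmark script; `sample_top_p` and its parameters are hypothetical names.

```python
import math
import random


def sample_top_p(logits, temperature=1.0, top_p=1.0, rng=random):
    """Sample a token index using temperature scaling and nucleus (top-p) filtering."""
    # Temperature-scaled softmax (subtract the max for numerical stability).
    scaled = [logit / temperature for logit in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most-probable tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Sample from the kept tokens in proportion to their probabilities.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]
```

A lower temperature sharpens the distribution before filtering, and a lower `top_p` discards more of the low-probability tail; with `top_p=1.0` every token remains eligible.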
### Examples

To test generation latency (e.g. batch size = 1) with different sampling strategies:
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
```

To test generation throughput with random prompts (e.g. large batch size):
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
```

With Mamba-2, you just need to change the model name:
``` sh
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
```
## Troubleshooting

### Precision

Our models were trained using PyTorch [AMP](https://pytorch.org/docs/stable/amp.html) for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary.
On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).

We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities,
as a first step please try a framework storing parameters in fp32 (such as AMP).
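
One way to see why storage precision can matter for a recurrence is to round a decay factor close to 1 to float16 and to float32 and compare the long-run behaviour. This is an illustrative sketch only: the decay value is hypothetical, and it is not claimed to reproduce the specific instabilities described above. It uses the standard-library `struct` module, whose `'e'` and `'f'` formats are IEEE half and single precision.

```python
import struct


def round_to(fmt, value):
    """Round a Python float to an IEEE format ('e' = float16, 'f' = float32)."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


decay = 0.9999  # hypothetical per-step state decay, close to 1
decay_fp16 = round_to("e", decay)
decay_fp32 = round_to("f", decay)

steps = 10_000
print("fp32 decay:", decay_fp32, "-> state scale after", steps, "steps:", decay_fp32 ** steps)
print("fp16 decay:", decay_fp16, "-> state scale after", steps, "steps:", decay_fp16 ** steps)
```

In float16 the decay rounds to exactly 1.0, so the state never shrinks, while in float32 it stays below 1 and the state decays as intended.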
### Initialization

Some parts of the model have initializations inherited from prior work on S4 models.
For [example](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L102), the $\Delta$ parameter is given a targeted range by initializing the bias of its linear projection.
However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in `nn.Linear` modules to zero).
If this is the case, you may have to add custom logic (e.g. this [line](https://github.com/state-spaces/mamba/blob/f0affcf69f06d1d06cef018ff640bf080a11c421/mamba_ssm/modules/mamba_simple.py#L104) turns off re-initializing in our trainer, but would be a no-op in any other framework)
that is specific to the training framework.
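
The idea of targeting a range for $\Delta$ through a bias can be sketched as follows: draw a step size log-uniformly from a target interval, then set the bias to the inverse softplus of that value, so that applying softplus to the bias recovers the step size. This is a minimal standard-library sketch under stated assumptions, not the repository's code: `sample_dt_bias` is a hypothetical helper, and the interval `[1e-3, 1e-1]` and the floor `1e-4` are assumed values.

```python
import math
import random


def softplus(value):
    """softplus(v) = log(1 + exp(v))."""
    return math.log1p(math.exp(value))


def sample_dt_bias(dt_min=1e-3, dt_max=1e-1, dt_floor=1e-4, rng=random):
    """Return (dt, bias) with dt log-uniform in [dt_min, dt_max] and softplus(bias) == dt."""
    log_dt = rng.random() * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
    dt = max(math.exp(log_dt), dt_floor)
    # Inverse of softplus: bias = log(exp(dt) - 1) = dt + log(1 - exp(-dt)).
    bias = dt + math.log(-math.expm1(-dt))
    return dt, bias


dt, bias = sample_dt_bias(rng=random.Random(0))
print(f"target dt = {dt:.5f}, bias = {bias:.5f}, softplus(bias) = {softplus(bias):.5f}")
print(f"softplus(0) = {softplus(0.0):.5f}")
```

The last line shows what a bias-zeroing hook would do under these assumptions: a zero bias gives softplus(0) = log 2, roughly 0.693, which is well outside the assumed target interval.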
## Additional Prerequisites for AMD cards

### Patching ROCm

If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.

1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation.
2. Apply the patch. Run with `sudo` in case you encounter permission issues.
   ```bash
   patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch
   ```
## Citation

If you use this codebase, or otherwise find our work valuable, please cite Mamba:
```
@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}

@inproceedings{mamba2,
  title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
  author={Dao, Tri and Gu, Albert},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2024}
}
```