
Commit 33dc96c

First release
1 parent 90bbb62 commit 33dc96c


41 files changed: 5,491 additions & 2 deletions

.github/workflows/publish.yaml

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
# This workflow will:
# - Create a new GitHub release
# - Build wheels for supported architectures
# - Deploy the wheels to the GitHub release
# - Release the source package (sdist) to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Build wheels and deploy

on:
  create:
    tags:
      - v*

jobs:

  setup_release:
    name: Create Release
    runs-on: ubuntu-latest
    steps:
      - name: Get the tag version
        id: extract_branch
        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
        shell: bash

      - name: Create Release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}
          release_name: ${{ steps.extract_branch.outputs.branch }}

  build_wheels:
    name: Build Wheel
    needs: setup_release
    runs-on: ${{ matrix.os }}

    strategy:
      fail-fast: false
      matrix:
        # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
        # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
        os: [ubuntu-20.04]
        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
        torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231127']
        cuda-version: ['11.8.0', '12.2.0']
        # We need separate wheels that either use the C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
        # PyTorch wheels currently don't use it, but nvcr images have PyTorch compiled with the C++11 ABI.
        # Without it we get an import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
        # when building without the C++11 ABI and using the wheel on nvcr images.
        cxx11_abi: ['FALSE', 'TRUE']
        exclude:
          # PyTorch <= 1.12 does not support Python 3.11
          - torch-version: '1.12.1'
            python-version: '3.11'
          # PyTorch >= 2.0 only supports Python >= 3.8
          - torch-version: '2.0.1'
            python-version: '3.7'
          - torch-version: '2.1.1'
            python-version: '3.7'
          - torch-version: '2.2.0.dev20231127'
            python-version: '3.7'
          # PyTorch <= 2.0 only supports CUDA <= 11.8
          - torch-version: '1.12.1'
            cuda-version: '12.2.0'
          - torch-version: '1.13.1'
            cuda-version: '12.2.0'
          - torch-version: '2.0.1'
            cuda-version: '12.2.0'

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Set CUDA and PyTorch versions
        run: |
          echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
          echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV

      - name: Free up disk space
        if: ${{ runner.os == 'Linux' }}
        # https://github.com/easimon/maximize-build-space/blob/master/action.yml
        # https://github.com/easimon/maximize-build-space/tree/test-report
        run: |
          sudo rm -rf /usr/share/dotnet
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL

      - name: Set up swap space
        if: runner.os == 'Linux'
        uses: pierotofy/set-swap-space@v1.0
        with:
          swap-size-gb: 10

      - name: Install CUDA ${{ matrix.cuda-version }}
        if: ${{ matrix.cuda-version != 'cpu' }}
        uses: Jimver/cuda-toolkit@v0.2.11
        id: cuda-toolkit
        with:
          cuda: ${{ matrix.cuda-version }}
          linux-local-args: '["--toolkit"]'
          # The default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
          # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }}
          method: 'network'
          # We need the CUDA libraries (e.g. cuSPARSE, cuSOLVER) for compiling PyTorch extensions,
          # not just nvcc
          # sub-packages: '["nvcc"]'

      - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
        run: |
          pip install --upgrade pip
          # If we don't install lit before installing PyTorch, we get an error for torch 2.0.1:
          # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
          pip install lit
          # We need to figure out which CUDA build of PyTorch to download,
          # e.g. the system CUDA version can be 11.7, but if torch==1.12 then we need to download the wheel from cu116.
          # This code is ugly, maybe there's a better way to do this.
          export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
          if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
          else
            pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
          fi
          nvcc --version
          python --version
          python -c "import torch; print('PyTorch:', torch.__version__)"
          python -c "import torch; print('CUDA:', torch.version.cuda)"
          python -c "from torch.utils import cpp_extension; print(cpp_extension.CUDA_HOME)"
        shell: bash

      - name: Build wheel
        run: |
          # We want setuptools >= 49.6.0, otherwise we can't compile the extension if the system CUDA version is 11.7 and the PyTorch CUDA version is 11.6
          # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
          # However this still fails, so I'm using a newer version of setuptools
          pip install setuptools==68.0.0
          pip install ninja packaging wheel
          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
          # Limit MAX_JOBS, otherwise the GitHub runner goes OOM
          MAX_JOBS=2 MAMBA_FORCE_BUILD="TRUE" MAMBA_FORCE_CXX11_ABI=${{ matrix.cxx11_abi }} python setup.py bdist_wheel --dist-dir=dist
          tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
          ls dist/*whl | xargs -I {} mv {} dist/${wheel_name}
          echo "wheel_name=${wheel_name}" >> $GITHUB_ENV

      - name: Log Built Wheels
        run: |
          ls dist

      - name: Get the tag version
        id: extract_branch
        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}

      - name: Get Release with tag
        id: get_current_release
        uses: joutvhu/get-release@v1
        with:
          tag_name: ${{ steps.extract_branch.outputs.branch }}
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

      - name: Upload Release Asset
        id: upload_release_asset
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
          asset_path: ./dist/${{ env.wheel_name }}
          asset_name: ${{ env.wheel_name }}
          asset_content_type: application/*

  publish_package:
    name: Publish package
    needs: [build_wheels]

    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3

      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install dependencies
        run: |
          pip install ninja packaging setuptools wheel twine
          # We don't want to download anything CUDA-related here
          pip install torch --index-url https://download.pytorch.org/whl/cpu

      - name: Build core package
        env:
          MAMBA_SKIP_CUDA_BUILD: "TRUE"
        run: |
          python setup.py sdist --dist-dir=dist

      - name: Deploy
        env:
          TWINE_USERNAME: "__token__"
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          python -m twine upload dist/*
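
The dense `TORCH_CUDA_VERSION` one-liner in the "Install PyTorch" step deserves a readable restatement. The sketch below is equivalent plain Python (the version tables are copied verbatim from the workflow; the function name is ours, for illustration only):

```
# Each PyTorch release only ships wheels for a range of CUDA builds, so the
# system CUDA version (MATRIX_CUDA_VERSION, e.g. "122" for CUDA 12.2) is
# clamped into the supported range for that release (MATRIX_TORCH_VERSION).
MIN_CUDA = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}
MAX_CUDA = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}

def torch_cuda_version(torch_version: str, system_cuda: int) -> int:
    return max(min(system_cuda, MAX_CUDA[torch_version]), MIN_CUDA[torch_version])

assert torch_cuda_version('2.1', 122) == 121   # CUDA 12.2 host -> cu121 wheel
assert torch_cuda_version('1.13', 122) == 117  # torch 1.13 tops out at cu117
assert torch_cuda_version('2.0', 110) == 117   # too-old host CUDA is raised to the minimum
```

The "Build wheel" step then embeds the same matrix values in the wheel's local version tag, so a hypothetical `mamba_ssm-1.0.0-cp310-cp310-linux_x86_64.whl` would be renamed to `mamba_ssm-1.0.0+cu121torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl`.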

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
[submodule "3rdparty/lm-evaluation-harness"]
	path = 3rdparty/lm-evaluation-harness
	url = https://github.com/EleutherAI/lm-evaluation-harness/

3rdparty/lm-evaluation-harness

Submodule lm-evaluation-harness added at a352061

AUTHORS

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
Tri Dao, tri@tridao.me
Albert Gu, agu@andrew.cmu.edu

README.md

Lines changed: 138 additions & 2 deletions
@@ -1,5 +1,141 @@
# Mamba

-This repository contains the code for the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
-The first official code release of the paper will be uploaded around noon EST, Monday Dec. 4.

![Mamba](assets/selection.png "Selective State Space")
> **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
> Albert Gu*, Tri Dao*\
> Paper: https://arxiv.org/abs/2312.00752

## Installation

- `pip install causal-conv1d`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
- `pip install mamba-ssm`: the core Mamba package.

If `pip` complains about PyTorch versions, try passing `--no-build-isolation` to `pip`.

Other requirements:
- Linux
- NVIDIA GPU
- PyTorch 1.12+
- CUDA 11.6+

## Usage

We expose several levels of interface with the Mamba model.

### Selective SSM

Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).

Source: [ops/selective_scan_interface.py](mamba_ssm/ops/selective_scan_interface.py).
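
For intuition, here is a minimal PyTorch sketch of the selective scan recurrence the paper describes: discretize A and B with the input-dependent step delta, run the linear recurrence, and project out with C. Names and shapes here are ours, for exposition only; this is not the fused CUDA kernel behind `selective_scan_interface.py`:

```
import torch

def selective_scan_sketch(u, delta, A, B, C):
    # u:     (batch, dim, length)   input sequence
    # delta: (batch, dim, length)   input-dependent step sizes
    # A:     (dim, state)           state transition parameters
    # B, C:  (batch, state, length) input-dependent input/output projections
    batch, dim, length = u.shape
    state = A.shape[1]
    h = u.new_zeros(batch, dim, state)                    # SSM hidden state
    ys = []
    for t in range(length):
        dA = torch.exp(delta[:, :, t, None] * A)          # discretized A
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        h = dA * h + dBu                                  # linear recurrence
        ys.append((h * C[:, None, :, t]).sum(dim=-1))     # output projection
    return torch.stack(ys, dim=-1)                        # (batch, dim, length)
```

The selectivity is precisely that `delta`, `B`, and `C` vary with the input; the repository replaces this sequential loop with a hardware-aware parallel scan.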
### Mamba Block

The main module of this repository is the Mamba architecture block, which wraps the selective SSM.

Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).

Usage:
```
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
```
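
A quick check of the parameter estimate in the comment: the block's two large input/output projections contribute roughly `2 * expand * d_model^2` and `expand * d_model^2` weights respectively, hence `3 * expand * d_model^2`; the conv, SSM, and bias parameters add a little more. Continuing the snippet above:

```
n_params = sum(p.numel() for p in model.parameters())
print(n_params)          # slightly above the estimate
print(3 * 2 * dim ** 2)  # 3 * expand * d_model^2 = 1536
```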

### Mamba Language Model

Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.

Source: [models/mixer_seq_simple.py](mamba_ssm/models/mixer_seq_simple.py).

This is an example of how to integrate Mamba into an end-to-end neural network.
This example is used in the generation scripts below.
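
As a sketch of how these pieces fit together, the snippet below loads one of the pretrained models from the next section and generates from a prompt. It assumes `MambaLMHeadModel` exposes `from_pretrained` and `generate` (mirroring how the benchmark script below drives the model) and borrows the GPT-NeoX tokenizer, a reasonable choice for Pile-trained models:

```
import torch
from transformers import AutoTokenizer

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m",
                                         device="cuda", dtype=torch.float16)

prompt = "My cat wrote all this CUDA code for a new language model and"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
out = model.generate(input_ids=input_ids, max_length=100,
                     temperature=0.5, top_p=0.9,
                     return_dict_in_generate=True)
print(tokenizer.decode(out.sequences[0]))
```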
## Pretrained Models

Pretrained models are uploaded to [HuggingFace](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`, `mamba-790m`, `mamba-1.4b`, `mamba-2.8b`.

The models will be autodownloaded by the generation script below.

These models were trained on the [Pile](https://huggingface.co/datasets/EleutherAI/pile) and follow the standard model dimensions described by GPT-3 and used by many open-source models:

| Parameters | Layers | Model dim. |
|------------|--------|------------|
| 130M       | 12     | 768        |
| 370M       | 24     | 1024       |
| 790M       | 24     | 1536       |
| 1.4B       | 24     | 2048       |
| 2.8B       | 32     | 2560       |

(The layer count of Mamba should be doubled, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)

Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.).
Performance is expected to be comparable to or better than other architectures trained on similar data, but not to match larger or fine-tuned models.

## Evaluations

To run zero-shot evaluations of models (corresponding to Table 3 of the paper), we use the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) library.

1. Pull the `lm-evaluation-harness` repo with `git submodule update --init --recursive`. We use the `big-refactor` branch.
2. Install `lm-evaluation-harness`: `pip install -e 3rdparty/lm-evaluation-harness`
3. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
```
python evals/lm_harness_eval.py --model mamba --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
```

Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.

## Inference

The script [benchmarks/benchmark_generation_mamba_simple.py](benchmarks/benchmark_generation_mamba_simple.py)
1. autoloads a model from the HuggingFace Hub,
2. generates completions of a user-specified prompt,
3. benchmarks the inference speed of this generation.

Other configurable options include the top-p (nucleus sampling) probability and the softmax temperature.

### Examples

To test generation latency (e.g. batch size = 1) with different sampling strategies:

```
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
```

To test generation throughput with random prompts (e.g. large batch size):
```
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
```

## Citation

If you use this codebase, or otherwise find our work valuable, please cite Mamba:
```
@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}
```

assets/selection.png

799 KB
