Commit 60dadf2

Mamba-2 code release
1 parent c59255a commit 60dadf2

21 files changed, +6707 -118 lines

README.md

Lines changed: 49 additions & 9 deletions
@@ -4,6 +4,9 @@
 > **Mamba: Linear-Time Sequence Modeling with Selective State Spaces**\
 > Albert Gu*, Tri Dao*\
 > Paper: https://arxiv.org/abs/2312.00752
+> **Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality**\
+> Tri Dao*, Albert Gu*\
+> Paper: https://arxiv.org/abs/2405.21060
 
 ## About
 
@@ -43,7 +46,7 @@ The main module of this repository is the Mamba architecture block wrapping the
 Source: [modules/mamba_simple.py](mamba_ssm/modules/mamba_simple.py).
 
 Usage:
-```
+``` python
 import torch
 from mamba_ssm import Mamba
 
@@ -60,6 +63,24 @@ y = model(x)
 assert y.shape == x.shape
 ```
 
+The Mamba-2 block is implemented at [modules/mamba2.py](mamba_ssm/modules/mamba2.py).
+
+A simpler version is at [modules/mamba2_simple.py](mamba_ssm/modules/mamba2_simple.py).
+
+The usage is similar to Mamba(-1):
+``` python
+from mamba_ssm import Mamba2
+model = Mamba2(
+    # This module uses roughly 3 * expand * d_model^2 parameters
+    d_model=dim,  # Model dimension d_model
+    d_state=64,   # SSM state expansion factor, typically 64 or 128
+    d_conv=4,     # Local convolution width
+    expand=2,     # Block expansion factor
+).to("cuda")
+y = model(x)
+assert y.shape == x.shape
+```
+
 ### Mamba Language Model
 
 Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.
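(Note: `dim` and `x` in the added snippet refer to the toy tensors defined in the Mamba(-1) example earlier in the README. For reference, a self-contained variant is sketched below; the concrete sizes (batch 2, sequence length 64, model dimension 256) are illustrative choices, not values taken from this diff.)

``` python
import torch
from mamba_ssm import Mamba2

# Illustrative toy shapes: (batch, length, dim). The model dimension is chosen
# so that expand * d_model is a multiple of the block's default head size.
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba2(
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")

y = model(x)
assert y.shape == x.shape
```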
@@ -70,12 +91,12 @@ This is an example of how to integrate Mamba into an end-to-end neural network.
 This example is used in the generation scripts below.
 
 
-
 ## Pretrained Models
 
 Pretrained models are uploaded to
 [Hugging Face](https://huggingface.co/state-spaces): `mamba-130m`, `mamba-370m`,
-`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
+`mamba-790m`, `mamba-1.4b`, `mamba-2.8b`, `mamba2-130m`, `mamba2-370m`,
+`mamba2-780m`, `mamba2-1.3b`, `mamba2-2.7b`, `transformerpp-2.7b`, `mamba2attn-2.7b`, trained on 300B tokens on the Pile, as well as `mamba-2.8b-slimpj`
 (trained on 600B tokens on the SlimPajama dataset).
 
 
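(As a sketch of how these checkpoints can be loaded programmatically, mirroring what `benchmarks/benchmark_generation_mamba_simple.py` does further down in this commit; the specific checkpoint name is just one of those listed above:)

``` python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# The Mamba(-2) checkpoints reuse the GPT-NeoX tokenizer, as in the benchmark script.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba2-2.7b", device="cuda", dtype=torch.float16
)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count
```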

@@ -106,17 +127,24 @@ library.
 
 1. Install `lm-evaluation-harness` by `pip install lm-eval==0.4.2`.
 2. Run evaluation with (more documentation at the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/big-refactor) repo):
-```
+``` sh
 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
 python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64
 ```
 
 To reproduce the results on the `mamba-2.8b-slimpj` model reported in the blogposts:
-```
+``` sh
 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
 lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256
 ```
 
+To run evaluations on Mamba-2 models, simply replace the model names:
+``` sh
+lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
+lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
+lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
+```
+
 Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.
 
 ## Inference
@@ -132,16 +160,21 @@ Other configurable options include the top-p (nucleus sampling) probability, and
 
 To test generation latency (e.g. batch size = 1) with different sampling strategies:
 
-```
+``` sh
 python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
 python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
 python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2
 ```
 
 To test generation throughput with random prompts (e.g. large batch size):
+``` sh
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64
 ```
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 128
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 128
+
+With Mamba-2, you just need to change the model name:
+``` sh
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
 ```
 
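(A rough programmatic counterpart to the CLI commands above. `model.generate` and its keyword arguments are not shown in this diff; they are assumed to mirror the benchmark script's flags (`--temperature`, `--topp`, `--topk`, `--repetition-penalty`), so treat this as a sketch rather than a documented API.)

``` python
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba2-2.7b", device="cuda", dtype=torch.float16
)

prompt = "My cat wrote all this CUDA code for a new language model and"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

# Sampling settings mirror the CLI flags above; max_length counts prompt + new tokens.
out = model.generate(
    input_ids=input_ids,
    max_length=input_ids.shape[1] + 100,
    temperature=0.7,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.2,
)
# Depending on the version, generate may return a tensor of token ids or an output object.
sequences = out.sequences if hasattr(out, "sequences") else out
print(tokenizer.batch_decode(sequences)[0])
```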

@@ -164,12 +197,19 @@ that is specific to the training framework.
 
 ## Citation
 
-If you use this codebase, or otherwise found our work valuable, please cite Mamba:
+If you use this codebase, or otherwise find our work valuable, please cite Mamba:
 ```
 @article{mamba,
   title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
   author={Gu, Albert and Dao, Tri},
   journal={arXiv preprint arXiv:2312.00752},
   year={2023}
 }
+@inproceedings{mamba2,
+  title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
+  author={Dao, Tri and Gu, Albert},
+  booktitle={International Conference on Machine Learning (ICML)},
+  year={2024}
+}
+
 ```

benchmarks/benchmark_generation_mamba_simple.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 dtype = torch.float16
 
 print(f"Loading model {args.model_name}")
-is_mamba = args.model_name.startswith("state-spaces/mamba-")
+is_mamba = args.model_name.startswith("state-spaces/mamba") or args.model_name.startswith("state-spaces/transformerpp")
 if is_mamba:
     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
     model = MambaLMHeadModel.from_pretrained(args.model_name, device=device, dtype=dtype)
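(Why the check was broadened: the new Mamba-2 and Transformer++ checkpoint names do not match the old `"state-spaces/mamba-"` prefix. A standalone illustration, not code from the repo:)

``` python
# Old vs. new model-name check from the diff above, applied to the released checkpoints.
def old_check(name: str) -> bool:
    return name.startswith("state-spaces/mamba-")

def new_check(name: str) -> bool:
    return name.startswith("state-spaces/mamba") or name.startswith("state-spaces/transformerpp")

for name in [
    "state-spaces/mamba-2.8b",          # old: True,  new: True
    "state-spaces/mamba2-2.7b",         # old: False (no "-" after "mamba"), new: True
    "state-spaces/transformerpp-2.7b",  # old: False, new: True
    "EleutherAI/pythia-2.8b",           # old: False, new: False
]:
    print(f"{name}: old={old_check(name)} new={new_check(name)}")
```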

mamba_ssm/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
-__version__ = "1.2.2"
+__version__ = "2.0.0"
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba
+from mamba_ssm.modules.mamba2 import Mamba2
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
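(A quick sanity check of the top-level exports after installing this release; the expected version string follows from the bump above:)

``` python
import mamba_ssm
from mamba_ssm import Mamba, Mamba2, MambaLMHeadModel
from mamba_ssm import selective_scan_fn, mamba_inner_fn

print(mamba_ssm.__version__)  # "2.0.0" as of this commit
print(Mamba2)                 # the newly exported Mamba-2 block
```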

mamba_ssm/distributed/__init__.py

Whitespace-only changes.
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+# version of PyTorch. The following 4 lines are for backward compatibility with
+# older PyTorch.
+if "all_gather_into_tensor" not in dir(torch.distributed):
+    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+if "reduce_scatter_tensor" not in dir(torch.distributed):
+    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
+
+
+# Raw operation, does not support autograd, but does support async
+def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    output = torch.empty(
+        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.all_gather_into_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
+    return output, handle
+
+
+# Raw operation, does not support autograd, but does support async
+def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    output = torch.empty(
+        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.reduce_scatter_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
+    return output, handle
+
+
+# Raw operation, does not support autograd, but does support async
+def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    input_ = input_.contiguous()
+    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
+    return input_, handle
+
+
+class AllGatherFunc(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_gather_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
+        return grad_input, None
+
+
+# Supports autograd, but does not support async
+all_gather = AllGatherFunc.apply
+
+
+class ReduceScatterFunc(torch.autograd.Function):
+    """Reduce-scatter the input from the sequence parallel region."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = reduce_scatter_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
+        return grad_input, None
+
+
+# Supports autograd, but does not support async
+reduce_scatter = ReduceScatterFunc.apply
+
+
+class AllReduceFunc(torch.autograd.Function):
+    """All-reduce the input from the sequence parallel region."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_reduce_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        return grad_output, None
+
+
+# Supports autograd, but does not support async
+all_reduce = AllReduceFunc.apply
+
+
+def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _shared_params=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_shared = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
+    }
+    for _, p in sorted(params_shared.items()):
+        with torch.no_grad():
+            # Broadcast needs src to be global rank, not group rank
+            torch.distributed.broadcast(
+                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
+            )
+
+
+# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
+def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
+    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_seqparallel = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
+    }
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    if grads:
+        with torch.no_grad():
+            coalesced = torch._utils._flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(coalesced, group=process_group)
+            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+
+
+def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
+    """Get the dim for the local rank derived from splitting dim on world_size processes.
+
+    The split may not be even across the world_size processes.
+    """
+    multiple = dim // multiple_of
+    div = multiple // world_size
+    mod = multiple % world_size
+    local_multiple = div + int(local_rank < mod)
+    return local_multiple * multiple_of
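(To make the uneven-split behavior of `get_dim_for_local_rank` concrete, here is a small standalone check; the function body is copied from the new file above, and the concrete numbers are only illustrative.)

``` python
# Splitting dim=768 across 5 ranks in units of multiple_of=64: 768/64 = 12 units,
# 12 // 5 = 2 units per rank with 12 % 5 = 2 left over, so the first two ranks
# get one extra unit each. The shards sum back to the full dim here.
def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of

shards = [get_dim_for_local_rank(768, 5, rank, multiple_of=64) for rank in range(5)]
print(shards)               # [192, 192, 128, 128, 128]
assert sum(shards) == 768
```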
