Commit afd5fb5

Update causal-conv1d to 1.2.0, make it optional
1 parent 9583c56 commit afd5fb5

File tree: 5 files changed, +24 −12 lines changed

- README.md
- mamba_ssm/modules/mamba_simple.py
- mamba_ssm/ops/selective_scan_interface.py
- setup.py
- tests/ops/triton/test_selective_state_update.py

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -13,7 +13,7 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla
 
 ## Installation
 
-- `pip install causal-conv1d>=1.1.0,<1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
+- [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
 - `pip install mamba-ssm`: the core Mamba package.
 
 It can also be built from source with `pip install .` from this repository.
```
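
Since causal-conv1d is now optional, a quick probe shows which path a given environment will take; a minimal sketch (not part of the package), mirroring the import guard this commit adds:

```python
# Minimal sketch: mamba-ssm now treats causal-conv1d as optional, and a
# missing install silently selects the slower reference path.
try:
    import causal_conv1d  # optional fused causal Conv1d CUDA kernels
    print("causal-conv1d found: the fused fast path can be used")
except ImportError:
    print("causal-conv1d missing: Mamba falls back to the reference path")
```

Note that `>` is a redirection operator in most shells, so the versioned install command needs quoting: `pip install "causal-conv1d>=1.2.0"`.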

mamba_ssm/modules/mamba_simple.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -15,7 +15,7 @@
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 except ImportError:
-    causal_conv1d_fn, causal_conv1d_update = None
+    causal_conv1d_fn, causal_conv1d_update = None, None
 
 try:
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
@@ -142,7 +142,7 @@ def forward(self, hidden_states, inference_params=None):
 
         A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
         # In the backward pass we write dx and dz next to each other to avoid torch.cat
-        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
+        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
             out = mamba_inner_fn(
                 xz,
                 self.conv1d.weight,
```
mamba_ssm/ops/selective_scan_interface.py

Lines changed: 19 additions & 7 deletions
```diff
@@ -6,8 +6,13 @@
 
 from einops import rearrange, repeat
 
-from causal_conv1d import causal_conv1d_fn
-import causal_conv1d_cuda
+try:
+    from causal_conv1d import causal_conv1d_fn
+    import causal_conv1d_cuda
+except ImportError:
+    causal_conv1d_fn = None
+    causal_conv1d_cuda = None
+
 import selective_scan_cuda
 
```

```diff
@@ -163,6 +168,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         """
             xz: (batch, dim, seqlen)
         """
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         assert checkpoint_lvl in [0, 1]
         L = xz.shape[-1]
         delta_rank = delta_proj_weight.shape[1]
@@ -178,7 +184,9 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
         x, z = xz.chunk(2, dim=1)
         conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+            x, conv1d_weight, conv1d_bias, None, None, None, True
+        )
         # We're being very careful here about the layout, to avoid extra transposes.
         # We want delta to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
```
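
Two changes here: the new `assert` defers the failure from import time to the moment the fused autograd function is actually used, with an install hint instead of a bare `ImportError`; and the internal `causal_conv1d_fwd` call passes three extra `None` arguments to match the widened kernel interface in causal-conv1d 1.2.0 (the diff does not name these parameters; reading them as new optional inputs is an assumption). A minimal sketch of the deferred-failure pattern, with hypothetical names:

```python
# Minimal sketch of the deferred-failure pattern: a failed optional import
# leaves a None sentinel, and the error surfaces only when the fused path
# is actually exercised. `fused_conv_forward` is a hypothetical stand-in.
causal_conv1d_cuda = None  # as if `except ImportError: causal_conv1d_cuda = None` ran

def fused_conv_forward(x):
    assert causal_conv1d_cuda is not None, (
        "causal_conv1d_cuda is not available. Please install causal-conv1d."
    )
    # ... would dispatch to the fused CUDA kernel here ...

try:
    fused_conv_forward(object())
except AssertionError as exc:
    print(exc)  # the user gets an install hint, not an import-time crash
```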
```diff
@@ -231,6 +239,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
     @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
          conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
         L = xz.shape[-1]
@@ -240,7 +249,9 @@ def backward(ctx, dout):
         if dout.stride(-1) != 1:
             dout = dout.contiguous()
         if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+                x, conv1d_weight, conv1d_bias, None, None, None, True
+            )
             delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                               "d (b l) -> b d l", l = L)
         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
@@ -285,8 +296,8 @@ def backward(ctx, dout):
         dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
         # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
+        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
         )
         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
```
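
The backward call widens the same way, and the left-hand side switches to starred unpacking: `causal_conv1d_bwd` in 1.2.0 returns additional trailing values that this code does not need, and `*_` absorbs them. A standalone illustration with a hypothetical stub:

```python
# Standalone illustration of the starred unpacking used above: the call
# site keeps working even though the 1.2.0 kernel returns extra trailing
# values (their meaning is not stated in this diff).
def causal_conv1d_bwd_stub():
    return "dx", "dconv1d_weight", "dconv1d_bias", "extra_a", "extra_b"

dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_bwd_stub()
print(dx, dconv1d_weight, dconv1d_bias)  # the extras land in _ and are dropped
```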
```diff
@@ -314,11 +325,12 @@ def mamba_inner_ref(
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True
 ):
+    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu")
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
     # We're being very careful here about the layout, to avoid extra transposes.
     # We want delta to have d as the slowest moving dimension
     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
```
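
In the reference path the fix is about argument binding: the old call passed `"silu"` positionally, which only works if `activation` is the fifth parameter. With new parameters apparently inserted ahead of it in 1.2.0, `"silu"` would bind to the wrong slot, so the call now names the keyword. A hedged usage sketch (assumes causal-conv1d >= 1.2.0 and a CUDA device):

```python
# Hedged usage sketch of the public causal_conv1d_fn API as called above.
import torch
from einops import rearrange
from causal_conv1d import causal_conv1d_fn

x = torch.randn(2, 64, 128, device="cuda")     # (batch, dim, seqlen)
weight = torch.randn(64, 1, 4, device="cuda")  # depthwise conv weight, (d 1 w)
bias = torch.randn(64, device="cuda")

# Naming the activation keeps the call robust to positional-signature changes:
y = causal_conv1d_fn(x, rearrange(weight, "d 1 w -> d w"), bias, activation="silu")
print(y.shape)  # torch.Size([2, 64, 128])
```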

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -271,6 +271,6 @@ def run(self):
         "einops",
         "triton",
         "transformers",
-        "causal_conv1d>=1.1.0,<1.2.0",
+        # "causal_conv1d>=1.2.0",
     ],
 )
```
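
With the requirement commented out, `pip install mamba-ssm` no longer pulls in causal-conv1d; users opt in via the README's `[Option]` line. A hypothetical alternative, not what this commit does, would be a setuptools extra:

```python
# Hypothetical alternative (not what this commit does): declare the kernel
# as a setuptools extra so users can opt in with
#     pip install "mamba-ssm[causal-conv1d]"
from setuptools import setup

setup(
    name="example-pkg",  # placeholder metadata for this sketch
    version="0.0.0",
    install_requires=["einops", "triton", "transformers"],
    extras_require={"causal-conv1d": ["causal_conv1d>=1.2.0"]},
)
```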

tests/ops/triton/test_selective_state_update.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -19,7 +19,7 @@
 # @pytest.mark.parametrize("dstate", [16])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
 # @pytest.mark.parametrize("dim", [2048])
-def test_causal_conv1d_update(dim, dstate, has_z, itype):
+def test_selective_state_update(dim, dstate, has_z, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
```
