Skip to content

Commit 2066c21

Browse files
authored
Bugfix causal_conv1d_fn interface (#168)
1 parent 5aa131d commit 2066c21

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mamba_ssm/ops/selective_scan_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def mamba_inner_ref(
318318
delta_rank = delta_proj_weight.shape[1]
319319
d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
320320
x, z = xz.chunk(2, dim=1)
321-
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
321+
x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu")
322322
# We're being very careful here about the layout, to avoid extra transposes.
323323
# We want delta to have d as the slowest moving dimension
324324
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.

0 commit comments

Comments
 (0)