@@ -6,8 +6,13 @@
 
 from einops import rearrange, repeat
 
-from causal_conv1d import causal_conv1d_fn
-import causal_conv1d_cuda
+try:
+    from causal_conv1d import causal_conv1d_fn
+    import causal_conv1d_cuda
+except ImportError:
+    causal_conv1d_fn = None
+    causal_conv1d_cuda = None
+
 import selective_scan_cuda
 
 
@@ -163,6 +168,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         """
             xz: (batch, dim, seqlen)
         """
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         assert checkpoint_lvl in [0, 1]
         L = xz.shape[-1]
         delta_rank = delta_proj_weight.shape[1]
@@ -178,7 +184,9 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
         x, z = xz.chunk(2, dim=1)
         conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+            x, conv1d_weight, conv1d_bias, None, None, None, True
+        )
         # We're being very careful here about the layout, to avoid extra transposes.
         # We want delta to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
@@ -231,6 +239,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
     @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
          conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
         L = xz.shape[-1]
@@ -240,7 +249,9 @@ def backward(ctx, dout):
         if dout.stride(-1) != 1:
             dout = dout.contiguous()
         if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+                x, conv1d_weight, conv1d_bias, None, None, None, True
+            )
             delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                               "d (b l) -> b d l", l=L)
         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
@@ -285,8 +296,8 @@ def backward(ctx, dout):
         dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
         # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
+        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
         )
         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
@@ -314,11 +325,12 @@ def mamba_inner_ref(
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True
 ):
+    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu")
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
     # We're being very careful here about the layout, to avoid extra transposes.
     # We want delta to have d as the slowest moving dimension
     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
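Two patterns in this diff are worth calling out: the guarded import, which falls back to None sentinels when causal-conv1d is missing and asserts only at the call sites that need it, and the switch to passing activation="silu" as a keyword to match the newer causal-conv1d call signature. Below is a minimal, self-contained sketch of the same pattern; the causal_conv_silu helper and its pure-PyTorch fallback are illustrative assumptions, not part of mamba_ssm.

import torch.nn.functional as F

try:
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None  # degrade gracefully; fail only when the fast path is requested


def causal_conv_silu(x, weight, bias=None):
    # x: (batch, dim, seqlen); weight: (dim, width), i.e. the layout produced by
    # rearrange(conv1d_weight, "d 1 w -> d w") in the diff above.
    if causal_conv1d_fn is not None:
        # Newer causal-conv1d takes the activation as a keyword argument, which is
        # why the last hunk replaces the trailing positional args with activation="silu".
        return causal_conv1d_fn(x, weight, bias, activation="silu")
    # Reference fallback: depthwise conv, left-padded so position t only sees inputs <= t.
    dim, width = weight.shape
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    return F.silu(out[..., : x.shape[-1]])

Using None sentinels plus call-site asserts, rather than raising at import time, keeps the package importable on machines without the CUDA kernels; that is exactly what the two assert causal_conv1d_cuda is not None hunks guard.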