
Commit f40e2ee

Add k_activation.py
1 parent 33aeaa0 commit f40e2ee

2 files changed: +154 -1 lines changed

mamba_ssm/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "2.0.1"
+__version__ = "2.0.2"
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba

k_activation.py (new file)

Lines changed: 153 additions & 0 deletions

@@ -0,0 +1,153 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import torch

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.jit
def _swiglu_fwd_kernel(
    X,
    Y,
    OUT,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_out_row,
    ncols,
    BLOCK_N: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    OUT += row * stride_out_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    out = x * tl.sigmoid(x) * y
    tl.store(OUT + cols, out, mask=cols < ncols)


def _swiglu_fwd(xy, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    if out is None:
        out = torch.empty_like(x)
    else:
        out = out.reshape(-1, out.shape[-1])
        assert out.shape == x.shape
    assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)
    return out.reshape(*batch_shape, out.shape[-1])


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': 32}),
        triton.Config({'BLOCK_N': 64}),
        triton.Config({'BLOCK_N': 128}),
        triton.Config({'BLOCK_N': 256}),
        triton.Config({'BLOCK_N': 512}),
        triton.Config({'BLOCK_N': 1024}),
    ],
    key=['ncols'],
)
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["OUT"] is not None})
@triton.jit
def _swiglu_bwd_kernel(
    X,
    Y,
    DOUT,
    OUT,
    DX,
    DY,
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dout_row,
    stride_out_row,
    stride_dx_row,
    stride_dy_row,
    ncols,
    BLOCK_N: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    start_col = tl.program_id(1) * BLOCK_N
    X += row * stride_x_row
    Y += row * stride_y_row
    DOUT += row * stride_dout_row
    if RECOMPUTE_OUTPUT:
        OUT += row * stride_out_row
    DX += row * stride_dx_row
    DY += row * stride_dy_row
    cols = start_col + tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)
    y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)
    dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)
    x_sigmoid = tl.sigmoid(x)
    dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout
    dy = x * x_sigmoid * dout
    tl.store(DX + cols, dx, mask=cols < ncols)
    tl.store(DY + cols, dy, mask=cols < ncols)
    if RECOMPUTE_OUTPUT:
        out = x * x_sigmoid * y
        tl.store(OUT + cols, out, mask=cols < ncols)


def _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):
    if xy.stride(-1) != 1:
        xy = xy.contiguous()
    if dout.stride(-1) != 1:
        dout = dout.contiguous()
    batch_shape = xy.shape[:-1]
    xy = xy.reshape(-1, xy.shape[-1])
    x, y = xy.chunk(2, dim=-1)
    dout = dout.reshape(-1, dout.shape[-1])
    assert dout.shape == x.shape
    if dxy is None:
        dxy = torch.empty_like(xy)
    else:
        dxy = dxy.reshape(-1, dxy.shape[-1])
        assert dxy.shape == xy.shape
    dx, dy = dxy.chunk(2, dim=-1)
    assert dx.stride(-1) == 1
    assert dy.stride(-1) == 1
    if recompute_output:
        if out is None:
            out = torch.empty_like(x)
        else:
            out = out.reshape(-1, out.shape[-1])
            assert out.shape == x.shape
        assert out.stride(-1) == 1
    M, N = x.shape
    grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))
    with torch.cuda.device(x.device.index):
        _swiglu_bwd_kernel[grid](x, y, dout, out if recompute_output else None, dx, dy,
                                 x.stride(0), y.stride(0), dout.stride(0),
                                 out.stride(0) if recompute_output else 0,
                                 dx.stride(0), dy.stride(0),
                                 N)
    if not recompute_output:
        return dxy.reshape(*batch_shape, dxy.shape[-1])
    else:
        return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])
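
For context (not part of this diff): a minimal sketch of how _swiglu_fwd and _swiglu_bwd could be wired together through a torch.autograd.Function so the kernel pair is differentiable end to end. The class name SwiGLU and the choice to save the packed xy tensor for backward are illustrative assumptions, not code from this commit.

# Hypothetical wrapper; names and structure are assumptions for illustration.
class SwiGLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, xy):
        # Save the packed (x, y) input so backward can recompute sigmoid(x).
        ctx.save_for_backward(xy)
        return _swiglu_fwd(xy)

    @staticmethod
    def backward(ctx, dout):
        xy, = ctx.saved_tensors
        return _swiglu_bwd(xy, dout)

swiglu = SwiGLU.apply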

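A quick sanity check (also not part of this commit) against a pure-PyTorch reference: the forward kernel computes x * sigmoid(x) * y, so its output should match F.silu(x) * y, and the dx formula in the backward kernel follows from d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))). The shapes, device, and tolerance below are arbitrary choices and assume a CUDA GPU with Triton installed.

# Hypothetical check; shapes and tolerance are illustrative assumptions.
import torch.nn.functional as F

xy = torch.randn(4, 256, device="cuda")        # last dim packs x (first half) and y (second half)
x, y = xy.chunk(2, dim=-1)

out = _swiglu_fwd(xy)
ref = F.silu(x) * y                             # silu(x) = x * sigmoid(x)
print(torch.allclose(out, ref, atol=1e-5))

dout = torch.randn_like(out)
dxy = _swiglu_bwd(xy, dout)
dx_ref = torch.sigmoid(x) * (1 + x * (1 - torch.sigmoid(x))) * y * dout
dy_ref = x * torch.sigmoid(x) * dout
print(torch.allclose(dxy, torch.cat([dx_ref, dy_ref], dim=-1), atol=1e-5))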
0 commit comments