Commit 28b1435

Support nheads and ngroups in inference kernel

1 parent 12d8550 commit 28b1435

1 file changed: +149 −78 lines
@@ -1,6 +1,6 @@
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao, Albert Gu.
 
-"""We want triton==2.1.0 for this
+"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
 """
 
 import math
@@ -22,20 +22,21 @@ def _selective_scan_update_kernel(
     # Pointers to matrices
     state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr,
     # Matrix dimensions
-    batch, dim, dstate,
+    batch, nheads, dim, dstate, nheads_ngroups_ratio,
     # Strides
-    stride_state_batch, stride_state_dim, stride_state_dstate,
-    stride_x_batch, stride_x_dim,
-    stride_dt_batch, stride_dt_dim,
-    stride_dt_bias_dim,
-    stride_A_dim, stride_A_dstate,
-    stride_B_batch, stride_B_dstate,
-    stride_C_batch, stride_C_dstate,
-    stride_D_dim,
-    stride_z_batch, stride_z_dim,
-    stride_out_batch, stride_out_dim,
+    stride_state_batch, stride_state_head, stride_state_dim, stride_state_dstate,
+    stride_x_batch, stride_x_head, stride_x_dim,
+    stride_dt_batch, stride_dt_head, stride_dt_dim,
+    stride_dt_bias_head, stride_dt_bias_dim,
+    stride_A_head, stride_A_dim, stride_A_dstate,
+    stride_B_batch, stride_B_group, stride_B_dstate,
+    stride_C_batch, stride_C_group, stride_C_dstate,
+    stride_D_head, stride_D_dim,
+    stride_z_batch, stride_z_head, stride_z_dim,
+    stride_out_batch, stride_out_head, stride_out_dim,
     # Meta-parameters
     DT_SOFTPLUS: tl.constexpr,
+    TIE_HDIM: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     HAS_DT_BIAS: tl.constexpr,
     HAS_D: tl.constexpr,
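The new `nheads_ngroups_ratio` argument (passed as `nheads // ngroups` by the wrapper below) is what lets several heads share one set of B/C projections. A minimal sketch of the head-to-group mapping the kernel computes with `pid_h // nheads_ngroups_ratio`, using made-up sizes:

```python
# Hypothetical sizes; the kernel receives nheads_ngroups_ratio = nheads // ngroups.
nheads, ngroups = 8, 2
nheads_ngroups_ratio = nheads // ngroups  # 4 heads share each B/C group
for pid_h in range(nheads):
    group = pid_h // nheads_ngroups_ratio
    print(f"head {pid_h} -> B/C group {group}")
# heads 0-3 read group 0, heads 4-7 read group 1
```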
@@ -44,14 +45,18 @@ def _selective_scan_update_kernel(
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
-    state_ptr += pid_b * stride_state_batch
-    x_ptr += pid_b * stride_x_batch
-    dt_ptr += pid_b * stride_dt_batch
-    B_ptr += pid_b * stride_B_batch
-    C_ptr += pid_b * stride_C_batch
+    pid_h = tl.program_id(axis=2)
+    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
+    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
+    if HAS_DT_BIAS:
+        dt_bias_ptr += pid_h * stride_dt_bias_head
+    A_ptr += pid_h * stride_A_head
+    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
+    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
     if HAS_Z:
-        z_ptr += pid_b * stride_z_batch
-    out_ptr += pid_b * stride_out_batch
+        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
+    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
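The pointer bumps above are plain strided indexing: offsetting a base pointer by `pid_b * stride_batch + pid_h * stride_head` addresses the `[pid_b, pid_h]` slice of a tensor laid out in flat memory. A small PyTorch illustration of the same arithmetic (tensor and names are made up):

```python
import torch

t = torch.arange(2 * 3 * 4).reshape(2, 3, 4)  # contiguous (batch, head, dim)
s_batch, s_head, s_dim = t.stride()

b, h, d = 1, 2, 3
flat_offset = b * s_batch + h * s_head + d * s_dim
assert t[b, h, d] == t.flatten()[flat_offset]  # same element, two addressings
```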
@@ -60,6 +65,8 @@ def _selective_scan_update_kernel(
     dt_ptrs = dt_ptr + offs_m * stride_dt_dim
     if HAS_DT_BIAS:
         dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
+    if HAS_D:
+        D_ptr += pid_h * stride_D_head
     A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
     B_ptrs = B_ptr + offs_n * stride_B_dstate
     C_ptrs = C_ptr + offs_n * stride_C_dstate
@@ -71,21 +78,34 @@ def _selective_scan_update_kernel(
 
     state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0)
     x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if HAS_DT_BIAS:
-        dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
-    if DT_SOFTPLUS:
-        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
-    A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
-    dA = tl.exp(A * dt[:, None])
+    if not TIE_HDIM:
+        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        if HAS_DT_BIAS:
+            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
+        if DT_SOFTPLUS:
+            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+        A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
+        dA = tl.exp(A * dt[:, None])
+    else:
+        dt = tl.load(dt_ptr).to(tl.float32)
+        if HAS_DT_BIAS:
+            dt += tl.load(dt_bias_ptr).to(tl.float32)
+        if DT_SOFTPLUS:
+            dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
+        A = tl.load(A_ptr).to(tl.float32)
+        dA = tl.exp(A * dt)  # scalar, not a matrix
+
     B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
     C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
     if HAS_D:
         D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
     if HAS_Z:
         z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
 
-    dB = B[None, :] * dt[:, None]
+    if not TIE_HDIM:
+        dB = B[None, :] * dt[:, None]
+    else:
+        dB = B * dt  # vector of size (dstate,)
     state = state * dA + dB * x[:, None]
     tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
     out = tl.sum(state * C[None, :], axis=1)
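The `TIE_HDIM` fast path above replaces the per-channel `dt`/`A` loads with scalar loads when every channel of a head shares the same value. The wrapper detects this from zero strides, which is what `Tensor.expand` produces; a small sketch of that detection, with made-up shapes:

```python
import torch

nheads, dim, dstate = 4, 64, 16
A_per_head = -torch.rand(nheads)
# expand() broadcasts without copying: the new axes get stride 0,
# so all (dim, dstate) elements of a head alias one scalar.
A = A_per_head[:, None, None].expand(nheads, dim, dstate)
print(A.stride())                               # (1, 0, 0)
print(A.stride(-1) == 0 and A.stride(-2) == 0)  # True -> TIE_HDIM branch
```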
@@ -99,94 +119,145 @@ def _selective_scan_update_kernel(
 def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
     """
     Argument:
-        state: (batch, dim, dstate)
-        x: (batch, dim)
-        dt: (batch, dim)
-        A: (dim, dstate)
-        B: (batch, dstate)
-        C: (batch, dstate)
-        D: (dim,)
-        z: (batch, dim)
-        dt_bias: (dim,)
+        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+        x: (batch, dim) or (batch, nheads, dim)
+        dt: (batch, dim) or (batch, nheads, dim)
+        A: (dim, dstate) or (nheads, dim, dstate)
+        B: (batch, dstate) or (batch, ngroups, dstate)
+        C: (batch, dstate) or (batch, ngroups, dstate)
+        D: (dim,) or (nheads, dim)
+        z: (batch, dim) or (batch, nheads, dim)
+        dt_bias: (dim,) or (nheads, dim)
     Return:
-        out: (batch, dim)
+        out: (batch, dim) or (batch, nheads, dim)
     """
-    batch, dim, dstate = state.shape
-    assert x.shape == (batch, dim)
+    has_heads = state.dim() > 3
+    if state.dim() == 3:
+        state = state.unsqueeze(1)
+    if x.dim() == 2:
+        x = x.unsqueeze(1)
+    if dt.dim() == 2:
+        dt = dt.unsqueeze(1)
+    if A.dim() == 2:
+        A = A.unsqueeze(0)
+    if B.dim() == 2:
+        B = B.unsqueeze(1)
+    if C.dim() == 2:
+        C = C.unsqueeze(1)
+    if D is not None and D.dim() == 1:
+        D = D.unsqueeze(0)
+    if z is not None and z.dim() == 2:
+        z = z.unsqueeze(1)
+    if dt_bias is not None and dt_bias.dim() == 1:
+        dt_bias = dt_bias.unsqueeze(0)
+    batch, nheads, dim, dstate = state.shape
+    assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
-    assert A.shape == (dim, dstate)
-    assert B.shape == (batch, dstate)
+    assert A.shape == (nheads, dim, dstate)
+    ngroups = B.shape[1]
+    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+    assert B.shape == (batch, ngroups, dstate)
     assert C.shape == B.shape
     if D is not None:
-        assert D.shape == (dim,)
+        assert D.shape == (nheads, dim)
     if z is not None:
         assert z.shape == x.shape
     if dt_bias is not None:
-        assert dt_bias.shape == (dim,)
+        assert dt_bias.shape == (nheads, dim)
     out = torch.empty_like(x)
-    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch)
-    z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0))
+    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
+    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
     # We don't want autotune since it will overwrite the state
     # We instead tune by hand.
     BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
                                else ((16, 4) if dstate <= 32 else
                                      ((8, 4) if dstate <= 64 else
                                       ((4, 4) if dstate <= 128 else
                                        ((4, 8))))))
+    # Zero strides mean expanded tensors: dt/A values tied across the head dim.
+    tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 and (dt_bias is None or dt_bias.stride(-1) == 0)
     with torch.cuda.device(x.device.index):
         _selective_scan_update_kernel[grid](
             state, x, dt, dt_bias, A, B, C, D, z, out,
-            batch, dim, dstate,
-            state.stride(0), state.stride(1), state.stride(2),
-            x.stride(0), x.stride(1),
-            dt.stride(0), dt.stride(1),
-            dt_bias.stride(0) if dt_bias is not None else 0,
-            A.stride(0), A.stride(1),
-            B.stride(0), B.stride(1),
-            C.stride(0), C.stride(1),
-            D.stride(0) if D is not None else 0,
-            z_strides[0], z_strides[1],
-            out.stride(0), out.stride(1),
+            batch, nheads, dim, dstate, nheads // ngroups,
+            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
+            x.stride(0), x.stride(1), x.stride(2),
+            dt.stride(0), dt.stride(1), dt.stride(2),
+            *((dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)),  # fallback must be a 2-tuple; a bare 0 can't be unpacked
+            A.stride(0), A.stride(1), A.stride(2),
+            B.stride(0), B.stride(1), B.stride(2),
+            C.stride(0), C.stride(1), C.stride(2),
+            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),
+            z_strides[0], z_strides[1], z_strides[2],
+            out.stride(0), out.stride(1), out.stride(2),
             dt_softplus,
+            tie_hdim,
             BLOCK_SIZE_M,
             num_warps=num_warps,
         )
+    if not has_heads:
+        out = out.squeeze(1)
     return out
 
 
 def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
     """
     Argument:
-        state: (batch, dim, dstate)
-        x: (batch, dim)
-        dt: (batch, dim)
-        A: (dim, dstate)
-        B: (batch, dstate)
-        C: (batch, dstate)
-        D: (dim,)
-        z: (batch, dim)
-        dt_bias: (dim,)
+        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+        x: (batch, dim) or (batch, nheads, dim)
+        dt: (batch, dim) or (batch, nheads, dim)
+        A: (dim, dstate) or (nheads, dim, dstate)
+        B: (batch, dstate) or (batch, ngroups, dstate)
+        C: (batch, dstate) or (batch, ngroups, dstate)
+        D: (dim,) or (nheads, dim)
+        z: (batch, dim) or (batch, nheads, dim)
+        dt_bias: (dim,) or (nheads, dim)
     Return:
-        out: (batch, dim)
+        out: (batch, dim) or (batch, nheads, dim)
     """
-    batch, dim, dstate = state.shape
-    assert x.shape == (batch, dim)
+    has_heads = state.dim() > 3
+    if state.dim() == 3:
+        state = state.unsqueeze(1)
+    if x.dim() == 2:
+        x = x.unsqueeze(1)
+    if dt.dim() == 2:
+        dt = dt.unsqueeze(1)
+    if A.dim() == 2:
+        A = A.unsqueeze(0)
+    if B.dim() == 2:
+        B = B.unsqueeze(1)
+    if C.dim() == 2:
+        C = C.unsqueeze(1)
+    if D is not None and D.dim() == 1:
+        D = D.unsqueeze(0)
+    if z is not None and z.dim() == 2:
+        z = z.unsqueeze(1)
+    if dt_bias is not None and dt_bias.dim() == 1:
+        dt_bias = dt_bias.unsqueeze(0)
+    batch, nheads, dim, dstate = state.shape
+    assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
-    assert A.shape == (dim, dstate)
-    assert B.shape == (batch, dstate)
+    assert A.shape == (nheads, dim, dstate)
+    ngroups = B.shape[1]
+    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
+    assert B.shape == (batch, ngroups, dstate)
     assert C.shape == B.shape
     if D is not None:
-        assert D.shape == (dim,)
+        assert D.shape == (nheads, dim)
     if z is not None:
         assert z.shape == x.shape
     if dt_bias is not None:
-        assert dt_bias.shape == (dim,)
+        assert dt_bias.shape == (nheads, dim)
     dt = dt + dt_bias
     dt = F.softplus(dt) if dt_softplus else dt
-    dA = torch.exp(rearrange(dt, "b d -> b d 1") * A)  # (batch, dim, dstate)
-    dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n")  # (batch, dim, dstate)
-    state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1"))  # (batch, dim, dstate
-    out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C)
+    dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * A)  # (batch, nheads, dim, dstate)
+    B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
+    C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups)  # (batch, nheads, dstate)
+    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(B, "b h n -> b h 1 n")  # (batch, nheads, dim, dstate)
+    state.copy_(state * dA + dB * rearrange(x, "b h d -> b h d 1"))  # (batch, nheads, dim, dstate)
+    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
     if D is not None:
         out += (x * D).to(out.dtype)
-    return (out if z is None else out * F.silu(z)).to(x.dtype)
+    out = (out if z is None else out * F.silu(z)).to(x.dtype)
+    if not has_heads:
+        out = out.squeeze(1)
+    return out
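As a usage sketch of the updated entry point, assuming the `mamba_ssm` package layout for the import path, with made-up sizes (shapes follow the updated docstring):

```python
import torch
from mamba_ssm.ops.triton.selective_state_update import (
    selective_state_update, selective_state_update_ref,
)

batch, nheads, ngroups, dim, dstate = 2, 8, 2, 64, 16
device = "cuda"

state = torch.randn(batch, nheads, dim, dstate, device=device)
x = torch.randn(batch, nheads, dim, device=device)
dt = torch.rand(batch, nheads, dim, device=device)
A = -torch.rand(nheads, dim, dstate, device=device)
B = torch.randn(batch, ngroups, dstate, device=device)  # one group serves nheads // ngroups heads
C = torch.randn(batch, ngroups, dstate, device=device)
D = torch.randn(nheads, dim, device=device)
dt_bias = torch.rand(nheads, dim, device=device)

out = selective_state_update(state, x, dt, A, B, C, D=D,
                             dt_bias=dt_bias, dt_softplus=True)
print(out.shape)  # torch.Size([2, 8, 64]) -- state was updated in place
```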

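Continuing that sketch, the Triton path can be sanity-checked against the reference implementation; both update `state` in place, so each call gets its own copy:

```python
state_kernel, state_ref = state.clone(), state.clone()

out_kernel = selective_state_update(state_kernel, x, dt, A, B, C, D=D,
                                    dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D,
                                     dt_bias=dt_bias, dt_softplus=True)

print(torch.allclose(out_kernel, out_ref, rtol=1e-3, atol=1e-3))
print(torch.allclose(state_kernel, state_ref, rtol=1e-3, atol=1e-3))
```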