1
- # Copyright (c) 2023 , Tri Dao.
1
+ # Copyright (c) 2024 , Tri Dao, Albert Gu .
2
2
3
- """We want triton==2.1.0 for this
3
+ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
4
4
"""
5
5
6
6
import math
@@ -22,20 +22,21 @@ def _selective_scan_update_kernel(
22
22
# Pointers to matrices
23
23
state_ptr , x_ptr , dt_ptr , dt_bias_ptr , A_ptr , B_ptr , C_ptr , D_ptr , z_ptr , out_ptr ,
24
24
# Matrix dimensions
25
- batch , dim , dstate ,
25
+ batch , nheads , dim , dstate , nheads_ngroups_ratio ,
26
26
# Strides
27
- stride_state_batch , stride_state_dim , stride_state_dstate ,
28
- stride_x_batch , stride_x_dim ,
29
- stride_dt_batch , stride_dt_dim ,
30
- stride_dt_bias_dim ,
31
- stride_A_dim , stride_A_dstate ,
32
- stride_B_batch , stride_B_dstate ,
33
- stride_C_batch , stride_C_dstate ,
34
- stride_D_dim ,
35
- stride_z_batch , stride_z_dim ,
36
- stride_out_batch , stride_out_dim ,
27
+ stride_state_batch , stride_state_head , stride_state_dim , stride_state_dstate ,
28
+ stride_x_batch , stride_x_head , stride_x_dim ,
29
+ stride_dt_batch , stride_dt_head , stride_dt_dim ,
30
+ stride_dt_bias_head , stride_dt_bias_dim ,
31
+ stride_A_head , stride_A_dim , stride_A_dstate ,
32
+ stride_B_batch , stride_B_group , stride_B_dstate ,
33
+ stride_C_batch , stride_C_group , stride_C_dstate ,
34
+ stride_D_head , stride_D_dim ,
35
+ stride_z_batch , stride_z_head , stride_z_dim ,
36
+ stride_out_batch , stride_out_head , stride_out_dim ,
37
37
# Meta-parameters
38
38
DT_SOFTPLUS : tl .constexpr ,
39
+ TIE_HDIM : tl .constexpr ,
39
40
BLOCK_SIZE_M : tl .constexpr ,
40
41
HAS_DT_BIAS : tl .constexpr ,
41
42
HAS_D : tl .constexpr ,
@@ -44,14 +45,18 @@ def _selective_scan_update_kernel(
44
45
):
45
46
pid_m = tl .program_id (axis = 0 )
46
47
pid_b = tl .program_id (axis = 1 )
47
- state_ptr += pid_b * stride_state_batch
48
- x_ptr += pid_b * stride_x_batch
49
- dt_ptr += pid_b * stride_dt_batch
50
- B_ptr += pid_b * stride_B_batch
51
- C_ptr += pid_b * stride_C_batch
48
+ pid_h = tl .program_id (axis = 2 )
49
+ state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
50
+ x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
51
+ dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
52
+ if HAS_DT_BIAS :
53
+ dt_bias_ptr += pid_h * stride_dt_bias_head
54
+ A_ptr += pid_h * stride_A_head
55
+ B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio ) * stride_B_group
56
+ C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio ) * stride_C_group
52
57
if HAS_Z :
53
- z_ptr += pid_b * stride_z_batch
54
- out_ptr += pid_b * stride_out_batch
58
+ z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
59
+ out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
55
60
56
61
offs_m = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
57
62
offs_n = tl .arange (0 , BLOCK_SIZE_DSTATE )
@@ -60,6 +65,8 @@ def _selective_scan_update_kernel(
60
65
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
61
66
if HAS_DT_BIAS :
62
67
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
68
+ if HAS_D :
69
+ D_ptr += pid_h * stride_D_head
63
70
A_ptrs = A_ptr + (offs_m [:, None ] * stride_A_dim + offs_n [None , :] * stride_A_dstate )
64
71
B_ptrs = B_ptr + offs_n * stride_B_dstate
65
72
C_ptrs = C_ptr + offs_n * stride_C_dstate
@@ -71,21 +78,34 @@ def _selective_scan_update_kernel(
71
78
72
79
state = tl .load (state_ptrs , mask = (offs_m [:, None ] < dim ) & (offs_n [None , :] < dstate ), other = 0.0 )
73
80
x = tl .load (x_ptrs , mask = offs_m < dim , other = 0.0 ).to (tl .float32 )
74
- dt = tl .load (dt_ptrs , mask = offs_m < dim , other = 0.0 ).to (tl .float32 )
75
- if HAS_DT_BIAS :
76
- dt += tl .load (dt_bias_ptrs , mask = offs_m < dim , other = 0.0 ).to (tl .float32 )
77
- if DT_SOFTPLUS :
78
- dt = tl .where (dt <= 20.0 , tl .math .log1p (tl .exp (dt )), dt )
79
- A = tl .load (A_ptrs , mask = (offs_m [:, None ] < dim ) & (offs_n [None , :] < dstate ), other = 0.0 ).to (tl .float32 )
80
- dA = tl .exp (A * dt [:, None ])
81
+ if not TIE_HDIM :
82
+ dt = tl .load (dt_ptrs , mask = offs_m < dim , other = 0.0 ).to (tl .float32 )
83
+ if HAS_DT_BIAS :
84
+ dt += tl .load (dt_bias_ptrs , mask = offs_m < dim , other = 0.0 ).to (tl .float32 )
85
+ if DT_SOFTPLUS :
86
+ dt = tl .where (dt <= 20.0 , tl .math .log1p (tl .exp (dt )), dt )
87
+ A = tl .load (A_ptrs , mask = (offs_m [:, None ] < dim ) & (offs_n [None , :] < dstate ), other = 0.0 ).to (tl .float32 )
88
+ dA = tl .exp (A * dt [:, None ])
89
+ else :
90
+ dt = tl .load (dt_ptr ).to (tl .float32 )
91
+ if HAS_DT_BIAS :
92
+ dt += tl .load (dt_bias_ptr ).to (tl .float32 )
93
+ if DT_SOFTPLUS :
94
+ dt = tl .where (dt <= 20.0 , tl .math .log1p (tl .exp (dt )), dt )
95
+ A = tl .load (A_ptr ).to (tl .float32 )
96
+ dA = tl .exp (A * dt ) # scalar, not a matrix
97
+
81
98
B = tl .load (B_ptrs , mask = offs_n < dstate , other = 0.0 ).to (tl .float32 )
82
99
C = tl .load (C_ptrs , mask = offs_n < dstate , other = 0.0 ).to (tl .float32 )
83
100
if HAS_D :
84
101
D = tl .load (D_ptrs , mask = offs_m < dim , other = 0.0 ).to (tl .float32 )
85
102
if HAS_Z :
86
103
z = tl .load (z_ptrs , mask = offs_m < dim , other = 0.0 ).to (tl .float32 )
87
104
88
- dB = B [None , :] * dt [:, None ]
105
+ if not TIE_HDIM :
106
+ dB = B [None , :] * dt [:, None ]
107
+ else :
108
+ dB = B * dt # vector of size (dstate,)
89
109
state = state * dA + dB * x [:, None ]
90
110
tl .store (state_ptrs , state , mask = (offs_m [:, None ] < dim ) & (offs_n [None , :] < dstate ))
91
111
out = tl .sum (state * C [None , :], axis = 1 )
@@ -99,94 +119,145 @@ def _selective_scan_update_kernel(
99
119
def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """Run one selective-scan recurrence step with the Triton kernel.

    Updates ``state`` in place (state = state * exp(A*dt) + (B*dt) * x) and
    returns the per-step output.  Tensors may be given either without a head
    dimension (single-head layout) or with one; head-less inputs are promoted
    internally and the output is squeezed back.  Runs on the tensors' CUDA
    device (the kernel launch is wrapped in ``torch.cuda.device``).

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    # Promote head-less inputs to the (batch, nheads, ...) layout so a single
    # kernel handles both calling conventions.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0))
    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16
                               else ((16, 4) if dstate <= 32 else
                                     ((8, 4) if dstate <= 64 else
                                      ((4, 4) if dstate <= 128 else
                                       ((4, 8))))))
    # dt/A (and dt_bias, when present) broadcast across the head-dim channels
    # iff their last stride is 0; guard dt_bias against None (it defaults to
    # None and .stride() on None would raise).
    tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0
                and (dt_bias is None or dt_bias.stride(-1) == 0))
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state, x, dt, dt_bias, A, B, C, D, z, out,
            batch, nheads, dim, dstate, nheads // ngroups,
            state.stride(0), state.stride(1), state.stride(2), state.stride(3),
            x.stride(0), x.stride(1), x.stride(2),
            dt.stride(0), dt.stride(1), dt.stride(2),
            # The kernel takes two dt_bias strides (head, dim); pass a (0, 0)
            # tuple when dt_bias is None — unpacking a bare 0 would raise a
            # TypeError and a single value would misalign the argument list.
            *((dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)),
            A.stride(0), A.stride(1), A.stride(2),
            B.stride(0), B.stride(1), B.stride(2),
            C.stride(0), C.stride(1), C.stride(2),
            # Same two-stride contract for D.
            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),
            z_strides[0], z_strides[1], z_strides[2],
            out.stride(0), out.stride(1), out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out
155
200
156
201
157
202
def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False):
    """Pure-PyTorch reference for one selective-scan recurrence step.

    Updates ``state`` in place (state = state * exp(A*dt) + (B*dt) * x) and
    returns the per-step output.  Accepts both head-less and multi-head
    layouts, mirroring the Triton fast path.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    # Promote head-less inputs to the (batch, nheads, ...) layout so a single
    # code path handles both calling conventions.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        # Fix: only add the bias when one is given; dt_bias defaults to None
        # and the unconditional add raised a TypeError in that case.
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(dt.unsqueeze(-1) * A)  # (batch, nheads, dim, dstate)
    # Broadcast each of the ngroups B/C rows to the heads it serves
    # ("b g n -> b (g h) n"): consecutive heads share the same group, matching
    # the kernel's `pid_h // nheads_ngroups_ratio` group lookup.
    heads_per_group = nheads // ngroups
    B = B.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)
    C = C.repeat_interleave(heads_per_group, dim=1)  # (batch, nheads, dstate)
    dB = dt.unsqueeze(-1) * B.unsqueeze(2)  # (batch, nheads, dim, dstate)
    state.copy_(state * dA + dB * x.unsqueeze(-1))  # (batch, nheads, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
0 commit comments