Commit e7d16fc

Sage Attention Enhancement (#1075)
# PR Brief

1. `is_casual` -> `is_causal` param fix.
2. Remove BLKQ, BLKK. The dispatch branch now depends entirely on the `km` param.
3. Other naming fixes, for easier review.

---------

Co-authored-by: nifeng <nemonameless@qq.com>
1 parent f5b0a0d commit e7d16fc
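For reviewers, a minimal caller-side sketch of item 1 in the PR Brief, adapted from the `sageattn_qk_int8_pv_fp16_triton` docstring example in this diff (the shapes and the `paddlemix` import are illustrative, not part of the change):

    import paddle
    import paddlemix

    batch_size, num_heads, seq_len, head_dim = 2, 24, 1024, 64
    # "NHD" layout: [bsz, seq_len, num_heads, head_dim]
    q = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")
    k = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")
    v = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")
    sm_scale = 1 / (head_dim ** 0.5)

    # the keyword is now is_causal (formerly is_casual); since **kwargs was also
    # dropped, passing the old spelling now fails loudly instead of being ignored
    o = paddlemix.triton_ops.sageattn_qk_int8_pv_fp16_triton(
        q, k, v, tensor_layout="NHD", is_causal=False,
        sm_scale=sm_scale, smooth_k=True, return_lse=False)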

File tree

2 files changed, +76 -84 lines changed


paddlemix/triton_ops/sageattn.py

old mode 100644, new mode 100755
+73 -83
@@ -13,25 +13,16 @@
 @paddle_use_triton(
     key=["1"]
 )
-def sageattn_quant_per_block_int8_kernel(
-    Input,
-    Output,
-    Scale,
-    L,
-    stride_iz,
-    stride_ih,
-    stride_in,
-    stride_oz,
-    stride_oh,
-    stride_on,
-    stride_sz,
-    stride_sh,
-    sm_scale,
-    h_attn: tl.constexpr,  # grid num, through compiling
-    bsz: tl.constexpr,  # grid num, through compiling
-    C: tl.constexpr,
-    BLK: tl.constexpr
-):
+def sageattn_quant_per_block_int8_kernel(Input, Output, Scale, L,
+                                         stride_iz, stride_ih, stride_in,
+                                         stride_oz, stride_oh, stride_on,
+                                         stride_sz, stride_sh,
+                                         sm_scale,
+                                         h_attn: tl.constexpr,  # grid num, through compiling
+                                         bsz: tl.constexpr,  # grid num, through compiling
+                                         C: tl.constexpr,
+                                         BLK: tl.constexpr
+                                         ):
     off_blk = tl.program_id(axis=0)
     off_h = tl.program_id(axis=1)
     off_b = tl.program_id(axis=2)
@@ -56,19 +47,17 @@ def sageattn_quant_per_block_int8_kernel(
 # per-block quant triton API
 def sageattn_quant_per_block_int8(x,
                                   km=None,
-                                  BLKQ=128, BLKK=64,
+                                  BLK=128,
                                   sm_scale=1.0,
-                                  tensor_layout="HND", q_or_k="q"):
+                                  tensor_layout="HND"):
     """
     [params]
     x: paddle.Tensor, dtype in fp16 or bf16, this is usually q or k input tensor.
     km: paddle.Tensor, the mean tensor of k tensor. Must be provided when the `x` is k tensor.
-    BLKQ: int, the BLK for computing q tensor. Default 128, which is an optimized value.
-    BLKK: int, the BLK for computing k tensor. Default 64, which is an optimized value.
+    BLK: int, the BLK for computing q & k tensor. Default 128 for q, 64 for k, which is an optimized value.
     sm_scale: float, the scale factor for dynamic quant.
     tensor_layout: string. Only in ['HND', 'NHD'], 'HND' -> [bsz, num_heads, seq_len, head_dim],
-                   'HND' -> [bsz, seq_len, num_heads, head_dim]
-    q_or_k: string. Only in ['q', 'k'], which should be clarified when using this API.
+                   'HND' -> [bsz, seq_len, num_heads, head_dim]
     [Examples]
     batch_size = 2
     num_heads = 24
@@ -86,11 +75,11 @@ def sageattn_quant_per_block_int8(x,
     km = paddle.mean(k, axis=seq_dim, keepdim=True)
 
     q_int8, q_scale = sageattn_quant_per_block_int8(
-        q, km=None, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='q')
+        q, km=None, BLK=BLKQ, sm_scale=sm_scale, tensor_layout=tensor_layout)
     k_int8, k_scale = sageattn_quant_per_block_int8(
-        k, km=km, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='k')
+        k, km=km, BLK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout)
     """
-    if km is not None and q_or_k == "k":
+    if km is not None:
         x = x - km
 
     if tensor_layout == "HND":
@@ -103,7 +92,7 @@ def sageattn_quant_per_block_int8(x,
         b, seq_len, h_attn, head_dim = x.shape
 
         stride_iz, stride_ih, stride_in = head_dim * seq_len * h_attn, head_dim * 1, head_dim * h_attn
-        stride_oz, stride_oh, stride_on = head_dim * seq_len * h_attn, head_dim * 1, head_dim * h_attn,
+        stride_oz, stride_oh, stride_on = head_dim * seq_len * h_attn, head_dim * 1, head_dim * h_attn
     else:
         raise ValueError(f"Unknown tensor layout: {tensor_layout}")
 
@@ -112,8 +101,7 @@ def sageattn_quant_per_block_int8(x,
 
     L = seq_len
     C = head_dim
-    BLK = BLKQ if q_or_k == "q" else BLKK
-    sm_scale = sm_scale * 1.44269504 if q_or_k == "q" else 1.0
+    sm_scale = sm_scale * 1.44269504 if km is None else 1.0
 
     stride_sz = h_attn * ((seq_len + BLK - 1) // BLK)
     stride_sh = (seq_len + BLK - 1) // BLK
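Review note on the hunk above: with `q_or_k` and the `BLKQ`/`BLKK` pair removed, the quant helper infers the role of `x` from `km` alone and takes a single per-call `BLK`. A hedged Python sketch of the equivalent branch (the helper name is illustrative, not part of the diff):

    # q path: km is None  -> no smoothing, fold log2(e) = 1.44269504 into the scale
    # k path: km provided -> subtract the per-sequence mean of k, keep scale at 1.0
    def role_from_km(x, km=None, sm_scale=1.0):
        if km is not None:
            x = x - km
        sm_scale = sm_scale * 1.44269504 if km is None else 1.0
        return x, sm_scale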
@@ -123,7 +111,7 @@ def sageattn_quant_per_block_int8(x,
 
         auto input_tensor = x;
         auto input_shape = x.shape();
-
+
         // define params
         int b, h_attn, seq_len, head_dim;
         int stride_iz, stride_ih, stride_in;
@@ -169,14 +157,13 @@ def sageattn_quant_per_block_int8(x,
         else {
             throw std::runtime_error("Unsupported tensor layout");
         }
-        int BLK = (q_or_k == std::string("q")) ? BLKQ : BLKK;
+
         auto scale_tensor = paddle::empty({b, h_attn, (seq_len + BLK - 1) / BLK},
                                           paddle::DataType::FLOAT32,
                                           x.place());
         int L = seq_len;
         int stride_sz = scale_tensor.strides()[0];
         int stride_sh = scale_tensor.strides()[1];
-        // int Grid = BLK;
         int bsz = b;
     """
 
@@ -190,9 +177,11 @@ def sageattn_quant_per_block_int8(x,
     # output_tensor & scale_tensor has beed defined in above areas
     prepare_ptr_for_triton_kernel = """
         // prepare tensor
-        auto Input = get_tensor_ptr(x);
-        auto Output = get_tensor_ptr(output_tensor);
-        auto Scale = get_tensor_ptr(scale_tensor);
+        CUdeviceptr input_ptrs[3] = {
+            get_tensor_ptr(x),
+            get_tensor_ptr(output_tensor),
+            get_tensor_ptr(scale_tensor)
+        };
     """
     return_tensor_names = "output_tensor, scale_tensor"
 
@@ -226,8 +215,8 @@ def sageattn_quant_per_block_int8(x,
 
     if in_dynamic_or_pir_mode():
         outs = _C_ops._run_custom_op(
-            op_name, x, km, BLKQ, BLKK,
-            sm_scale, tensor_layout, q_or_k
+            op_name, x, km, BLK,
+            sm_scale, tensor_layout
         )
         return outs[0], outs[1]
     else:
@@ -243,11 +232,9 @@ def sageattn_quant_per_block_int8(x,
             type=op_name,
             inputs=inputs,
             attrs={
-                "BLKQ": BLKQ,
-                "BLKK": BLKK,
+                "BLK": BLK,
                 "sm_scale": sm_scale,
                 "tensor_layout": tensor_layout,
-                "q_or_k": q_or_k
             },
             outputs={"output_tensor": out_int8, "scale_tensor": out_scale}
         )
@@ -257,7 +244,7 @@ def sageattn_quant_per_block_int8(x,
 @paddle_use_triton(
     key=["1"]
 )
-def sageattn_attn_fwd_casual_false_kernel(
+def sageattn_attn_fwd_causal_false_kernel(
     Q, K, V, Q_scale, K_scale, Out, Lse,
     stride_qz, stride_qh, stride_qn,
     stride_kz, stride_kh, stride_kn,
@@ -332,7 +319,7 @@ def sageattn_attn_fwd_casual_false_kernel(
     tl.store(lse_ptrs, l_i, mask = (offs_m < qo_len))
 
 
-def sageattn_forward_casual_false(q, k, v,
+def sageattn_forward_causal_false(q, k, v,
                                   q_scale, k_scale,
                                   output_dtype="float16",
                                   tensor_layout="HND",
@@ -365,7 +352,7 @@ def sageattn_forward_casual_false(q, k, v,
     km = paddle.mean(k, axis=seq_dim, keepdim=True)
 
     q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
-    o, lse = sageattn_forward_casual_false(q_int8, k_int8, v, q_scale, k_scale,
+    o, lse = sageattn_forward_causal_false(q_int8, k_int8, v, q_scale, k_scale,
                                            output_dtype="float16", tensor_layout=tensor_layout)
     """
     assert output_dtype in ["float16", "bfloat16"]
@@ -471,18 +458,18 @@ def sageattn_forward_casual_false(q, k, v,
         int BSZ = b;
     """
 
-    op_name = "triton_sageattn_attn_fwd_casual_false"
+    op_name = "triton_sageattn_attn_fwd_causal_false"
     op_name += get_dtype_str(q.dtype)
     op_name += f"_BSZ{BSZ}_seq{qo_len}_h{h_qo}_dim{HEAD_DIM_K}"
 
-    sageattn_attn_fwd_casual_false_config = []
+    sageattn_attn_fwd_causal_false_config = []
     if head_dim == 64:
-        sageattn_attn_fwd_casual_false_config.append({
+        sageattn_attn_fwd_causal_false_config.append({
             "num_warps": 4,
             "num_stages": 3
         })
     else:
-        sageattn_attn_fwd_casual_false_config.append({
+        sageattn_attn_fwd_causal_false_config.append({
             "num_warps": 8,
             "num_stages": 4
         })
@@ -501,23 +488,25 @@ def sageattn_forward_casual_false(q, k, v,
             lse_tensor = paddle::empty({1,1,1}, paddle::DataType::FLOAT32, paddle::CPUPlace());
         }
 
-        auto Q = get_tensor_ptr(q);
-        auto K = get_tensor_ptr(k);
-        auto V = get_tensor_ptr(v);
-        auto Q_scale = get_tensor_ptr(q_scale);
-        auto K_scale = get_tensor_ptr(k_scale);
-        auto Out = get_tensor_ptr(out_tensor);
-        auto Lse = get_tensor_ptr(lse_tensor);
+        CUdeviceptr input_ptrs[7] = {
+            get_tensor_ptr(q),
+            get_tensor_ptr(k),
+            get_tensor_ptr(v),
+            get_tensor_ptr(q_scale),
+            get_tensor_ptr(k_scale),
+            get_tensor_ptr(out_tensor),
+            get_tensor_ptr(lse_tensor)
+        };
     """
     return_tensor_names = "out_tensor, lse_tensor"
     template_used = rendering_common_template(
-        sageattn_forward_casual_false,
+        sageattn_forward_causal_false,
         prepare_attr_for_triton_kernel=prepare_attr_for_triton_kernel,
         prepare_ptr_for_triton_kernel=prepare_ptr_for_triton_kernel,
         return_tensor_names=return_tensor_names
     )
     grid = ("(qo_len + BLOCK_M - 1) / BLOCK_M", "h_qo", "BSZ")
-    sageattn_attn_fwd_casual_false_kernel[(op_name, template_used, grid, sageattn_attn_fwd_casual_false_config)](
+    sageattn_attn_fwd_causal_false_kernel[(op_name, template_used, grid, sageattn_attn_fwd_causal_false_config)](
         Q=q,
         K=k,
         V=v,
@@ -590,7 +579,7 @@ def sageattn_forward_casual_false(q, k, v,
 @paddle_use_triton(
     key=["1"]
 )
-def sageattn_attn_fwd_casual_true_kernel(
+def sageattn_attn_fwd_causal_true_kernel(
     Q, K, V, Q_scale, K_scale, Out, Lse,
     stride_qz, stride_qh, stride_qn,
     stride_kz, stride_kh, stride_kn,
@@ -721,7 +710,7 @@ def sageattn_attn_fwd_casual_true_kernel(
     tl.store(lse_ptrs, l_i, mask = (offs_m < qo_len))
 
 
-def sageattn_forward_casual_true(q, k, v,
+def sageattn_forward_causal_true(q, k, v,
                                  q_scale, k_scale,
                                  output_dtype="float16",
                                  tensor_layout="HND",
@@ -754,7 +743,7 @@ def sageattn_forward_casual_true(q, k, v,
     km = paddle.mean(k, axis=seq_dim, keepdim=True)
 
     q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
-    o, lse = sageattn_forward_casual_true(q_int8, k_int8, v, q_scale, k_scale,
+    o, lse = sageattn_forward_causal_true(q_int8, k_int8, v, q_scale, k_scale,
                                           output_dtype="float16", tensor_layout=tensor_layout)
     """
     assert output_dtype in ["float16", "bfloat16"]
@@ -862,18 +851,18 @@ def sageattn_forward_casual_true(q, k, v,
         int BSZ = b;
     """
 
-    op_name = "triton_sageattn_attn_fwd_casual_true"
+    op_name = "triton_sageattn_attn_fwd_causal_true"
     op_name += get_dtype_str(q.dtype)
     op_name += f"_BSZ{BSZ}_seq{qo_len}_h{h_qo}_dim{HEAD_DIM_K}"
 
-    sageattn_attn_fwd_casual_true_config = []
+    sageattn_attn_fwd_causal_true_config = []
     if head_dim == 64:
-        sageattn_attn_fwd_casual_true_config.append({
+        sageattn_attn_fwd_causal_true_config.append({
             "num_warps": 4,
             "num_stages": 4
         })
     else:
-        sageattn_attn_fwd_casual_true_config.append({
+        sageattn_attn_fwd_causal_true_config.append({
             "num_warps": 8,
             "num_stages": 4
         })
@@ -891,24 +880,26 @@ def sageattn_forward_casual_true(q, k, v,
         } else {
             lse_tensor = paddle::empty({1,1,1}, paddle::DataType::FLOAT32, paddle::CPUPlace());
         }
-
-        auto Q = get_tensor_ptr(q);
-        auto K = get_tensor_ptr(k);
-        auto V = get_tensor_ptr(v);
-        auto Q_scale = get_tensor_ptr(q_scale);
-        auto K_scale = get_tensor_ptr(k_scale);
-        auto Out = get_tensor_ptr(out_tensor);
-        auto Lse = get_tensor_ptr(lse_tensor);
+
+        CUdeviceptr input_ptrs[7] = {
+            get_tensor_ptr(q),
+            get_tensor_ptr(k),
+            get_tensor_ptr(v),
+            get_tensor_ptr(q_scale),
+            get_tensor_ptr(k_scale),
+            get_tensor_ptr(out_tensor),
+            get_tensor_ptr(lse_tensor)
+        };
     """
     return_tensor_names = "out_tensor, lse_tensor"
     template_used = rendering_common_template(
-        sageattn_forward_casual_true,
+        sageattn_forward_causal_true,
         prepare_attr_for_triton_kernel=prepare_attr_for_triton_kernel,
        prepare_ptr_for_triton_kernel=prepare_ptr_for_triton_kernel,
         return_tensor_names=return_tensor_names
     )
     grid = ("(qo_len + BLOCK_M - 1) / BLOCK_M", "h_qo", "BSZ")
-    sageattn_attn_fwd_casual_true_kernel[(op_name, template_used, grid, sageattn_attn_fwd_casual_true_config)](
+    sageattn_attn_fwd_causal_true_kernel[(op_name, template_used, grid, sageattn_attn_fwd_causal_true_config)](
         Q=q,
         K=k,
         V=v,
@@ -982,9 +973,9 @@ def sageattn_forward_casual_true(q, k, v,
 def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None,
                    tensor_layout="HND"):
     q_int8, q_scale = sageattn_quant_per_block_int8(
-        q, km=None, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='q')
+        q, km=None, BLK=BLKQ, sm_scale=sm_scale, tensor_layout=tensor_layout)
     k_int8, k_scale = sageattn_quant_per_block_int8(
-        k, km=km, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='k')
+        k, km=km, BLK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout)
     return q_int8, q_scale, k_int8, k_scale
 
 
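Review note on the hunk above: `per_block_int8` keeps its `BLKQ=128` / `BLKK=64` defaults and simply forwards one of them as `BLK` per call, so its callers are unchanged. A hedged usage sketch (q, k, and sm_scale are assumed to be in scope; the seq_dim mapping is an assumption based on the layouts documented earlier in this file):

    seq_dim = 2  # assumed: 2 for "HND" [bsz, num_heads, seq_len, head_dim], 1 for "NHD"
    km = paddle.mean(k, axis=seq_dim, keepdim=True)  # per-sequence mean of k, used to smooth k
    q_int8, q_scale, k_int8, k_scale = per_block_int8(
        q, k, km=km, BLKQ=128, BLKK=64, sm_scale=sm_scale, tensor_layout="HND")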
@@ -993,11 +984,10 @@ def sageattn_qk_int8_pv_fp16_triton(
     k: paddle.Tensor,
     v: paddle.Tensor,
     tensor_layout: str = "HND",
-    is_casual: bool = False,
+    is_causal: bool = False,
     sm_scale: Optional[float] = None,
     smooth_k: bool = True,
     return_lse: bool = False,
-    **kwargs
 ) -> paddle.Tensor:
     """
     Examples:
@@ -1010,7 +1000,7 @@ def sageattn_qk_int8_pv_fp16_triton(
     v = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")
     sm_scale = 1 / (head_dim ** 0.5)
 
-    o = paddlemix.triton_ops.sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout="NHD", is_casual=False, sm_scale=sm_scale, smooth_k=True, return_lse=False)
+    o = paddlemix.triton_ops.sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout="NHD", is_causal=False, sm_scale=sm_scale, smooth_k=True, return_lse=False)
     """
     dtype = q.dtype
     assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
@@ -1050,10 +1040,10 @@ def sageattn_qk_int8_pv_fp16_triton(
 
     q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
 
-    if is_casual:
-        o, lse = sageattn_forward_casual_true(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)
+    if is_causal:
+        o, lse = sageattn_forward_causal_true(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)
     else:
-        o, lse = sageattn_forward_casual_false(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)
+        o, lse = sageattn_forward_causal_false(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)
 
     o = o[..., :head_dim_og]
 