@paddle_use_triton(
    key=["1"]
)
-def sageattn_quant_per_block_int8_kernel(
-    Input,
-    Output,
-    Scale,
-    L,
-    stride_iz,
-    stride_ih,
-    stride_in,
-    stride_oz,
-    stride_oh,
-    stride_on,
-    stride_sz,
-    stride_sh,
-    sm_scale,
-    h_attn: tl.constexpr,  # grid num, through compiling
-    bsz: tl.constexpr,  # grid num, through compiling
-    C: tl.constexpr,
-    BLK: tl.constexpr
-):
+def sageattn_quant_per_block_int8_kernel(Input, Output, Scale, L,
+                                         stride_iz, stride_ih, stride_in,
+                                         stride_oz, stride_oh, stride_on,
+                                         stride_sz, stride_sh,
+                                         sm_scale,
+                                         h_attn: tl.constexpr,  # grid num, through compiling
+                                         bsz: tl.constexpr,  # grid num, through compiling
+                                         C: tl.constexpr,
+                                         BLK: tl.constexpr
+                                         ):
    off_blk = tl.program_id(axis=0)
    off_h = tl.program_id(axis=1)
    off_b = tl.program_id(axis=2)
@@ -56,19 +47,17 @@ def sageattn_quant_per_block_int8_kernel(
# per-block quant triton API
def sageattn_quant_per_block_int8(x,
                                  km=None,
-                                  BLKQ=128, BLKK=64,
+                                  BLK=128,
                                  sm_scale=1.0,
-                                  tensor_layout="HND", q_or_k="q"):
+                                  tensor_layout="HND"):
    """
    [params]
    x: paddle.Tensor, dtype in fp16 or bf16, this is usually the q or k input tensor.
    km: paddle.Tensor, the mean tensor of the k tensor. Must be provided when `x` is the k tensor.
-    BLKQ: int, the BLK for computing q tensor. Default 128, which is an optimized value.
-    BLKK: int, the BLK for computing k tensor. Default 64, which is an optimized value.
+    BLK: int, the block size used when quantizing the q or k tensor. Typically 128 for q and 64 for k, which are tuned defaults.
    sm_scale: float, the scale factor for dynamic quant.
    tensor_layout: string. Only in ['HND', 'NHD'], 'HND' -> [bsz, num_heads, seq_len, head_dim],
-                   'NHD' -> [bsz, seq_len, num_heads, head_dim]
-    q_or_k: string. Only in ['q', 'k'], which should be clarified when using this API.
+                   'NHD' -> [bsz, seq_len, num_heads, head_dim]
    [Examples]
    batch_size = 2
    num_heads = 24
@@ -86,11 +75,11 @@ def sageattn_quant_per_block_int8(x,
    km = paddle.mean(k, axis=seq_dim, keepdim=True)

    q_int8, q_scale = sageattn_quant_per_block_int8(
-        q, km=None, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='q')
+        q, km=None, BLK=BLKQ, sm_scale=sm_scale, tensor_layout=tensor_layout)
    k_int8, k_scale = sageattn_quant_per_block_int8(
-        k, km=km, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='k')
+        k, km=km, BLK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout)
    """
-    if km is not None and q_or_k == "k":
+    if km is not None:
        x = x - km

    if tensor_layout == "HND":
@@ -103,7 +92,7 @@ def sageattn_quant_per_block_int8(x,
        b, seq_len, h_attn, head_dim = x.shape

        stride_iz, stride_ih, stride_in = head_dim * seq_len * h_attn, head_dim * 1, head_dim * h_attn
-        stride_oz, stride_oh, stride_on = head_dim * seq_len * h_attn, head_dim * 1, head_dim * h_attn,
+        stride_oz, stride_oh, stride_on = head_dim * seq_len * h_attn, head_dim * 1, head_dim * h_attn
    else:
        raise ValueError(f"Unknown tensor layout: {tensor_layout}")

@@ -112,8 +101,7 @@ def sageattn_quant_per_block_int8(x,

    L = seq_len
    C = head_dim
-    BLK = BLKQ if q_or_k == "q" else BLKK
-    sm_scale = sm_scale * 1.44269504 if q_or_k == "q" else 1.0
+    sm_scale = sm_scale * 1.44269504 if km is None else 1.0

    stride_sz = h_attn * ((seq_len + BLK - 1) // BLK)
    stride_sh = (seq_len + BLK - 1) // BLK
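
For readers following the quantization logic: the kernel launched here produces one fp32 scale per (batch, head, BLK-row block), which is why the scale strides above are derived from `(seq_len + BLK - 1) // BLK` and why the C++ glue later in this diff allocates the scale tensor as `{b, h_attn, (seq_len + BLK - 1) / BLK}`. A rough host-side sketch of the same per-block INT8 scheme is given below for orientation only; the helper name and the plain Paddle ops are illustrative assumptions, not part of this file.

import paddle

def per_block_int8_reference(x, km=None, BLK=128, sm_scale=1.0):
    # Illustrative sketch (assumed helper): x is [bsz, h_attn, seq_len, head_dim] ("HND").
    if km is not None:                                         # k path: subtract the per-head mean
        x = x - km
    scale_in = sm_scale * 1.44269504 if km is None else 1.0    # fold sm_scale * log2(e) into q only
    x = x.astype("float32") * scale_in
    b, h, L, C = x.shape
    n_blk = (L + BLK - 1) // BLK
    pad = n_blk * BLK - L
    if pad:
        x = paddle.concat([x, paddle.zeros([b, h, pad, C], dtype=x.dtype)], axis=2)
    blocks = x.reshape([b, h, n_blk, BLK, C])
    scale = paddle.clip(blocks.abs().max(axis=[3, 4]), min=1e-8) / 127.0   # one fp32 scale per block
    x_int8 = paddle.round(blocks / scale.unsqueeze(-1).unsqueeze(-1)).astype("int8")
    return x_int8.reshape([b, h, n_blk * BLK, C])[:, :, :L, :], scale

Calling it as `x_int8, x_scale = per_block_int8_reference(q, BLK=128, sm_scale=sm_scale)` mirrors what the Triton path returns as `output_tensor, scale_tensor`.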
@@ -123,7 +111,7 @@ def sageattn_quant_per_block_int8(x,

    auto input_tensor = x;
    auto input_shape = x.shape();
-
+
    // define params
    int b, h_attn, seq_len, head_dim;
    int stride_iz, stride_ih, stride_in;
@@ -169,14 +157,13 @@ def sageattn_quant_per_block_int8(x,
    else {
        throw std::runtime_error("Unsupported tensor layout");
    }
-    int BLK = (q_or_k == std::string("q")) ? BLKQ : BLKK;
+
    auto scale_tensor = paddle::empty({b, h_attn, (seq_len + BLK - 1) / BLK},
                                      paddle::DataType::FLOAT32,
                                      x.place());
    int L = seq_len;
    int stride_sz = scale_tensor.strides()[0];
    int stride_sh = scale_tensor.strides()[1];
-    // int Grid = BLK;
    int bsz = b;
    """

@@ -190,9 +177,11 @@ def sageattn_quant_per_block_int8(x,
    # output_tensor & scale_tensor have been defined above
    prepare_ptr_for_triton_kernel = """
    // prepare tensor
-    auto Input = get_tensor_ptr(x);
-    auto Output = get_tensor_ptr(output_tensor);
-    auto Scale = get_tensor_ptr(scale_tensor);
+    CUdeviceptr input_ptrs[3] = {
+        get_tensor_ptr(x),
+        get_tensor_ptr(output_tensor),
+        get_tensor_ptr(scale_tensor)
+    };
    """
    return_tensor_names = "output_tensor, scale_tensor"

@@ -226,8 +215,8 @@ def sageattn_quant_per_block_int8(x,

    if in_dynamic_or_pir_mode():
        outs = _C_ops._run_custom_op(
-            op_name, x, km, BLKQ, BLKK,
-            sm_scale, tensor_layout, q_or_k
+            op_name, x, km, BLK,
+            sm_scale, tensor_layout
        )
        return outs[0], outs[1]
    else:
@@ -243,11 +232,9 @@ def sageattn_quant_per_block_int8(x,
            type=op_name,
            inputs=inputs,
            attrs={
-                "BLKQ": BLKQ,
-                "BLKK": BLKK,
+                "BLK": BLK,
                "sm_scale": sm_scale,
                "tensor_layout": tensor_layout,
-                "q_or_k": q_or_k
            },
            outputs={"output_tensor": out_int8, "scale_tensor": out_scale}
        )
@@ -257,7 +244,7 @@ def sageattn_quant_per_block_int8(x,
@paddle_use_triton(
    key=["1"]
)
-def sageattn_attn_fwd_casual_false_kernel(
+def sageattn_attn_fwd_causal_false_kernel(
    Q, K, V, Q_scale, K_scale, Out, Lse,
    stride_qz, stride_qh, stride_qn,
    stride_kz, stride_kh, stride_kn,
@@ -332,7 +319,7 @@ def sageattn_attn_fwd_casual_false_kernel(
    tl.store(lse_ptrs, l_i, mask=(offs_m < qo_len))


-def sageattn_forward_casual_false(q, k, v,
+def sageattn_forward_causal_false(q, k, v,
                                  q_scale, k_scale,
                                  output_dtype="float16",
                                  tensor_layout="HND",
@@ -365,7 +352,7 @@ def sageattn_forward_casual_false(q, k, v,
    km = paddle.mean(k, axis=seq_dim, keepdim=True)

    q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
-    o, lse = sageattn_forward_casual_false(q_int8, k_int8, v, q_scale, k_scale,
+    o, lse = sageattn_forward_causal_false(q_int8, k_int8, v, q_scale, k_scale,
                                           output_dtype="float16", tensor_layout=tensor_layout)
    """
    assert output_dtype in ["float16", "bfloat16"]
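
A small usage sketch continuing the docstring example above (the variable names are assumptions carried over from it): the renamed forward also takes `return_lse`, and when it is False the C++ glue later in this diff only allocates a 1x1x1 placeholder lse tensor, so the second return value should be ignored in that case.

o, lse = sageattn_forward_causal_false(
    q_int8, k_int8, v, q_scale, k_scale,
    output_dtype="float16", tensor_layout=tensor_layout, return_lse=True)
o2, lse_dummy = sageattn_forward_causal_false(
    q_int8, k_int8, v, q_scale, k_scale,
    output_dtype="float16", tensor_layout=tensor_layout, return_lse=False)
# With return_lse=False, lse_dummy is only the 1x1x1 placeholder allocated by the C++ code.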
@@ -471,18 +458,18 @@ def sageattn_forward_casual_false(q, k, v,
    int BSZ = b;
    """

-    op_name = "triton_sageattn_attn_fwd_casual_false"
+    op_name = "triton_sageattn_attn_fwd_causal_false"
    op_name += get_dtype_str(q.dtype)
    op_name += f"_BSZ{BSZ}_seq{qo_len}_h{h_qo}_dim{HEAD_DIM_K}"

-    sageattn_attn_fwd_casual_false_config = []
+    sageattn_attn_fwd_causal_false_config = []
    if head_dim == 64:
-        sageattn_attn_fwd_casual_false_config.append({
+        sageattn_attn_fwd_causal_false_config.append({
            "num_warps": 4,
            "num_stages": 3
        })
    else:
-        sageattn_attn_fwd_casual_false_config.append({
+        sageattn_attn_fwd_causal_false_config.append({
            "num_warps": 8,
            "num_stages": 4
        })
@@ -501,23 +488,25 @@ def sageattn_forward_casual_false(q, k, v,
        lse_tensor = paddle::empty({1,1,1}, paddle::DataType::FLOAT32, paddle::CPUPlace());
    }

-    auto Q = get_tensor_ptr(q);
-    auto K = get_tensor_ptr(k);
-    auto V = get_tensor_ptr(v);
-    auto Q_scale = get_tensor_ptr(q_scale);
-    auto K_scale = get_tensor_ptr(k_scale);
-    auto Out = get_tensor_ptr(out_tensor);
-    auto Lse = get_tensor_ptr(lse_tensor);
+    CUdeviceptr input_ptrs[7] = {
+        get_tensor_ptr(q),
+        get_tensor_ptr(k),
+        get_tensor_ptr(v),
+        get_tensor_ptr(q_scale),
+        get_tensor_ptr(k_scale),
+        get_tensor_ptr(out_tensor),
+        get_tensor_ptr(lse_tensor)
+    };
    """
    return_tensor_names = "out_tensor, lse_tensor"
    template_used = rendering_common_template(
-        sageattn_forward_casual_false,
+        sageattn_forward_causal_false,
        prepare_attr_for_triton_kernel=prepare_attr_for_triton_kernel,
        prepare_ptr_for_triton_kernel=prepare_ptr_for_triton_kernel,
        return_tensor_names=return_tensor_names
    )
    grid = ("(qo_len + BLOCK_M - 1) / BLOCK_M", "h_qo", "BSZ")
-    sageattn_attn_fwd_casual_false_kernel[(op_name, template_used, grid, sageattn_attn_fwd_casual_false_config)](
+    sageattn_attn_fwd_causal_false_kernel[(op_name, template_used, grid, sageattn_attn_fwd_causal_false_config)](
        Q=q,
        K=k,
        V=v,
@@ -590,7 +579,7 @@ def sageattn_forward_casual_false(q, k, v,
@paddle_use_triton(
    key=["1"]
)
-def sageattn_attn_fwd_casual_true_kernel(
+def sageattn_attn_fwd_causal_true_kernel(
    Q, K, V, Q_scale, K_scale, Out, Lse,
    stride_qz, stride_qh, stride_qn,
    stride_kz, stride_kh, stride_kn,
@@ -721,7 +710,7 @@ def sageattn_attn_fwd_casual_true_kernel(
    tl.store(lse_ptrs, l_i, mask=(offs_m < qo_len))


-def sageattn_forward_casual_true(q, k, v,
+def sageattn_forward_causal_true(q, k, v,
                                 q_scale, k_scale,
                                 output_dtype="float16",
                                 tensor_layout="HND",
@@ -754,7 +743,7 @@ def sageattn_forward_casual_true(q, k, v,
    km = paddle.mean(k, axis=seq_dim, keepdim=True)

    q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)
-    o, lse = sageattn_forward_casual_true(q_int8, k_int8, v, q_scale, k_scale,
+    o, lse = sageattn_forward_causal_true(q_int8, k_int8, v, q_scale, k_scale,
                                          output_dtype="float16", tensor_layout=tensor_layout)
    """
    assert output_dtype in ["float16", "bfloat16"]
@@ -862,18 +851,18 @@ def sageattn_forward_casual_true(q, k, v,
    int BSZ = b;
    """

-    op_name = "triton_sageattn_attn_fwd_casual_true"
+    op_name = "triton_sageattn_attn_fwd_causal_true"
    op_name += get_dtype_str(q.dtype)
    op_name += f"_BSZ{BSZ}_seq{qo_len}_h{h_qo}_dim{HEAD_DIM_K}"

-    sageattn_attn_fwd_casual_true_config = []
+    sageattn_attn_fwd_causal_true_config = []
    if head_dim == 64:
-        sageattn_attn_fwd_casual_true_config.append({
+        sageattn_attn_fwd_causal_true_config.append({
            "num_warps": 4,
            "num_stages": 4
        })
    else:
-        sageattn_attn_fwd_casual_true_config.append({
+        sageattn_attn_fwd_causal_true_config.append({
            "num_warps": 8,
            "num_stages": 4
        })
@@ -891,24 +880,26 @@ def sageattn_forward_casual_true(q, k, v,
    } else {
        lse_tensor = paddle::empty({1,1,1}, paddle::DataType::FLOAT32, paddle::CPUPlace());
    }
-
-    auto Q = get_tensor_ptr(q);
-    auto K = get_tensor_ptr(k);
-    auto V = get_tensor_ptr(v);
-    auto Q_scale = get_tensor_ptr(q_scale);
-    auto K_scale = get_tensor_ptr(k_scale);
-    auto Out = get_tensor_ptr(out_tensor);
-    auto Lse = get_tensor_ptr(lse_tensor);
+
+    CUdeviceptr input_ptrs[7] = {
+        get_tensor_ptr(q),
+        get_tensor_ptr(k),
+        get_tensor_ptr(v),
+        get_tensor_ptr(q_scale),
+        get_tensor_ptr(k_scale),
+        get_tensor_ptr(out_tensor),
+        get_tensor_ptr(lse_tensor)
+    };
    """
    return_tensor_names = "out_tensor, lse_tensor"
    template_used = rendering_common_template(
-        sageattn_forward_casual_true,
+        sageattn_forward_causal_true,
        prepare_attr_for_triton_kernel=prepare_attr_for_triton_kernel,
        prepare_ptr_for_triton_kernel=prepare_ptr_for_triton_kernel,
        return_tensor_names=return_tensor_names
    )
    grid = ("(qo_len + BLOCK_M - 1) / BLOCK_M", "h_qo", "BSZ")
-    sageattn_attn_fwd_casual_true_kernel[(op_name, template_used, grid, sageattn_attn_fwd_casual_true_config)](
+    sageattn_attn_fwd_causal_true_kernel[(op_name, template_used, grid, sageattn_attn_fwd_causal_true_config)](
        Q=q,
        K=k,
        V=v,
@@ -982,9 +973,9 @@ def sageattn_forward_casual_true(q, k, v,
def per_block_int8(q, k, km=None, BLKQ=128, BLKK=64, sm_scale=None,
                   tensor_layout="HND"):
    q_int8, q_scale = sageattn_quant_per_block_int8(
-        q, km=None, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='q')
+        q, km=None, BLK=BLKQ, sm_scale=sm_scale, tensor_layout=tensor_layout)
    k_int8, k_scale = sageattn_quant_per_block_int8(
-        k, km=km, BLKQ=BLKQ, BLKK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout, q_or_k='k')
+        k, km=km, BLK=BLKK, sm_scale=sm_scale, tensor_layout=tensor_layout)
    return q_int8, q_scale, k_int8, k_scale

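
A minimal usage sketch of this helper, assuming the same shapes as the docstring examples elsewhere in the file ('NHD' layout, fp16 inputs); the variable values are illustrative only:

import paddle

batch_size, num_heads, seq_len, head_dim = 2, 24, 1024, 64
q = paddle.randn([batch_size, seq_len, num_heads, head_dim], dtype="float16")
k = paddle.randn([batch_size, seq_len, num_heads, head_dim], dtype="float16")
km = paddle.mean(k, axis=1, keepdim=True)      # seq axis is 1 for the "NHD" layout
sm_scale = 1 / (head_dim ** 0.5)

# q is quantized with BLKQ=128, k with BLKK=64; km is only subtracted on the k path.
q_int8, q_scale, k_int8, k_scale = per_block_int8(
    q, k, km=km, sm_scale=sm_scale, tensor_layout="NHD")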
@@ -993,11 +984,10 @@ def sageattn_qk_int8_pv_fp16_triton(
    k: paddle.Tensor,
    v: paddle.Tensor,
    tensor_layout: str = "HND",
-    is_casual: bool = False,
+    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    return_lse: bool = False,
-    **kwargs
) -> paddle.Tensor:
    """
    Examples:
@@ -1010,7 +1000,7 @@ def sageattn_qk_int8_pv_fp16_triton(
    v = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")
    sm_scale = 1 / (head_dim ** 0.5)

-    o = paddlemix.triton_ops.sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout="NHD", is_casual=False, sm_scale=sm_scale, smooth_k=True, return_lse=False)
+    o = paddlemix.triton_ops.sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout="NHD", is_causal=False, sm_scale=sm_scale, smooth_k=True, return_lse=False)
    """
    dtype = q.dtype
    assert dtype in [paddle.float16, paddle.bfloat16], "Input tensors must be in dtype of paddle.float16 or paddle.bfloat16"
@@ -1050,10 +1040,10 @@ def sageattn_qk_int8_pv_fp16_triton(

    q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, km=km, sm_scale=sm_scale, tensor_layout=tensor_layout)

-    if is_casual:
-        o, lse = sageattn_forward_casual_true(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)
+    if is_causal:
+        o, lse = sageattn_forward_causal_true(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)
    else:
-        o, lse = sageattn_forward_casual_false(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)
+        o, lse = sageattn_forward_causal_false(q_int8, k_int8, v, q_scale, k_scale, output_dtype="float16", tensor_layout=tensor_layout, return_lse=return_lse)

    o = o[..., :head_dim_og]

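
To close the loop, a short end-to-end sketch of the renamed public entry point. The causal call with the values below is an assumed usage based on the docstring example, not taken verbatim from the file; note that the output is sliced back to the original head_dim (`head_dim_og`) as shown just above.

import paddle
import paddlemix

batch_size, num_heads, seq_len, head_dim = 2, 24, 1024, 64
q = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")
k = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")
v = paddle.randn(shape=(batch_size, seq_len, num_heads, head_dim), dtype="float16")

o = paddlemix.triton_ops.sageattn_qk_int8_pv_fp16_triton(
    q, k, v, tensor_layout="NHD", is_causal=True,
    sm_scale=1 / (head_dim ** 0.5), smooth_k=True, return_lse=False)
print(o.shape)  # expected: [2, 1024, 24, 64] for the "NHD" layout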