@@ -88,6 +88,7 @@ def naive_attention_impl(
     scale=1.0,
     cache_k_dequant_scales=None,
     cache_v_dequant_scales=None,
+    use_cachekv_int8="None",
 ):
     batch = query.shape[0]
     heads = query.shape[1]
@@ -98,13 +99,18 @@ def naive_attention_impl(
         key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
         key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
         key = key.reshape([batch, heads, seq_len, head_dim])
+
+    if use_cachekv_int8 == "dynamic":
+        unsqueeze_shape = [2, 3]
+    elif use_cachekv_int8 == "static":
+        unsqueeze_shape = [0, 2, 3]
     if pre_cache_k is not None:
         key = paddle.concat([pre_cache_k, key], axis=2)
     if cache_k is not None:
         if cache_k_dequant_scales is not None:
             dequant_cache_k = (
                 (cache_k.astype('float32') - 128.0)
-                * cache_k_dequant_scales.unsqueeze([0, 2, 3])
+                * cache_k_dequant_scales.unsqueeze(unsqueeze_shape)
             ).astype(key.dtype)
             key = paddle.concat([dequant_cache_k, key], axis=2)
         else:
@@ -119,7 +125,7 @@ def naive_attention_impl(
         if cache_v_dequant_scales is not None:
             dequant_cache_v = (
                 (cache_v.astype('float32') - 128.0)
-                * cache_v_dequant_scales.unsqueeze([0, 2, 3])
+                * cache_v_dequant_scales.unsqueeze(unsqueeze_shape)
             ).astype(value.dtype)
             value = paddle.concat([dequant_cache_v, value], axis=2)
         else:
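
The two unsqueeze shapes mirror how the scale tensors are laid out: dynamic scales hold one value per (batch, head) pair (shape [batch_size, num_head], as set up in the test changes below), while static scales hold one value per head (shape [num_head]). Unsqueezing at [2, 3] versus [0, 2, 3] lifts both to rank 4 so they broadcast against the [batch, heads, seq_len, head_dim] cache. Note that unsqueeze_shape is only assigned on the "dynamic" and "static" branches, so callers that pass dequant scales must also pass use_cachekv_int8, as the updated tests do. A minimal broadcast sketch (shapes assumed for illustration, not part of the diff):

import paddle

batch, heads, seq_len, head_dim = 2, 8, 16, 64
cache_k = paddle.randint(0, 256, [batch, heads, seq_len, head_dim]).astype('float32')

# Dynamic: per-(batch, head) scales -> [batch, heads, 1, 1] after unsqueeze([2, 3])
dyn_scales = paddle.rand([batch, heads])
dequant_dyn = (cache_k - 128.0) * dyn_scales.unsqueeze([2, 3])

# Static: per-head scales -> [1, heads, 1, 1] after unsqueeze([0, 2, 3])
static_scales = paddle.rand([heads])
dequant_static = (cache_k - 128.0) * static_scales.unsqueeze([0, 2, 3])

assert dequant_dyn.shape == cache_k.shape
assert dequant_static.shape == cache_k.shape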
@@ -1306,6 +1312,13 @@ def test_all(self):
         )


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA, or the CUDA version is below 11.4, "
+    "or the device's compute capability is not 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecPTQDequant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
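
get_cuda_version and is_sm_supported are helpers defined earlier in this test file, outside the diff. For context, a hypothetical equivalent of such a guard, based on the device's compute capability (names and import path assumed, not taken from this file):

import paddle
from paddle.base import core  # import path assumed; older releases expose paddle.fluid.core

def _sm_supported():
    # Hypothetical stand-in for is_sm_supported: accept Ampere (8.x) or
    # Hopper (sm_90), matching the "8.x or 90" wording in the skip message.
    if not core.is_compiled_with_cuda():
        return False
    major, minor = paddle.device.cuda.get_device_capability()
    return major == 8 or (major == 9 and minor == 0)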
@@ -1641,6 +1654,13 @@ def test_all(self):
         )


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA, or the CUDA version is below 11.4, "
+    "or the device's compute capability is not 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecPTQDequantQuantShiftSmooth(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
@@ -2013,6 +2033,13 @@ def test_all(self):
         )


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA, or the CUDA version is below 11.4, "
+    "or the device's compute capability is not 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecQuant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
@@ -2282,6 +2309,13 @@ def test_all(self):
         )


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA, or the CUDA version is below 11.4, "
+    "or the device's compute capability is not 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecCacheKVDynamicQuant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
@@ -2339,16 +2373,16 @@ def setUp(self):
         self.cache_k = paddle.zeros(shape=self.cache_shape, dtype='uint8')
         self.cache_v = paddle.zeros(shape=self.cache_shape, dtype='uint8')
         self.cache_k_quant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )
         self.cache_v_quant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )
         self.cache_k_dequant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )
         self.cache_v_dequant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )

         self.block_tables = paddle.zeros(
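
Dynamic cache-KV int8 quantization computes scales at runtime per batch element as well as per head, which is why these four scale tensors grow from [num_head] to [batch_size, num_head]. One plausible abs-max scheme matching the zero-point-128 convention used in naive_attention_impl (illustrative only; the fused kernel's exact scheme is not shown in this diff):

import paddle

batch, heads, seq_len, head_dim = 2, 8, 16, 64
k = paddle.rand([batch, heads, seq_len, head_dim])

# Hypothetical dynamic scales: one abs-max scale per (batch, head) pair.
dequant_scales = k.abs().max(axis=[2, 3]) / 127.0          # [batch, heads]
k_uint8 = paddle.clip(
    paddle.round(k / dequant_scales.unsqueeze([2, 3])) + 128.0, 0.0, 255.0
).astype('uint8')

# Round trip, mirroring the dequant step in the naive reference above.
k_restored = (k_uint8.astype('float32') - 128.0) * dequant_scales.unsqueeze([2, 3])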
@@ -2510,6 +2544,7 @@ def test_all(self):
                     self.scale,
                     cache_k_dequant_scales=self.cache_k_dequant_scales,
                     cache_v_dequant_scales=self.cache_v_dequant_scales,
+                    use_cachekv_int8="dynamic",
                 )
                 .transpose([0, 2, 1, 3])
                 .reshape([self.batch_size, -1])
@@ -2555,6 +2590,13 @@ def test_all(self):
         )


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA, or the CUDA version is below 11.4, "
+    "or the device's compute capability is not 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecCacheKVStaticQuant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
@@ -2795,6 +2837,7 @@ def test_all(self):
                     self.scale,
                     cache_k_dequant_scales=self.cache_k_dequant_scales,
                     cache_v_dequant_scales=self.cache_v_dequant_scales,
+                    use_cachekv_int8="static",
                 )
                 .transpose([0, 2, 1, 3])
                 .reshape([self.batch_size, -1])