Commit ccc4545: fix ci build error and fix llama 65b error
1 parent: 245476c

5 files changed: +68 -17 lines

paddle/fluid/framework/tensor_util.cc (1 addition, 1 deletion)

@@ -307,7 +307,7 @@ void TensorCopySync(const phi::DenseTensor& src,
     return;
   }
 
-  VLOG(0) << "TensorCopySync " << src.dims() << " from " << src.place()
+  VLOG(3) << "TensorCopySync " << src.dims() << " from " << src.place()
          << " to " << dst_place;
   src.check_memory_size();
   dst->Resize(src.dims());
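
Lowering the verbosity from VLOG(0), which always prints, to VLOG(3) silences this per-copy trace by default. The message stays reachable through glog's verbosity switch, which PaddlePaddle honors. A minimal sketch, assuming a Linux environment where GLOG_v is read when paddle initializes:

import os

# Raise glog verbosity before importing paddle so VLOG(3) messages,
# including the TensorCopySync trace above, are emitted again.
os.environ["GLOG_v"] = "3"

import paddle  # noqa: E402  (imported after setting the env var on purpose)

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
if paddle.is_compiled_with_cuda():
    y = x.cuda()  # a cross-place copy that should surface the trace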

paddle/fluid/pybind/inference_api.cc (6 additions, 1 deletion)

@@ -345,10 +345,15 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
         static_cast<int64_t *>(paddle_tensor.data<int64_t>()),
         shape,
         ToPaddleInferPlace(paddle_tensor.place().GetType()));
+  } else if (paddle_tensor.dtype() == phi::DataType::UINT8) {
+    tensor.ShareExternalData(
+        static_cast<uint8_t *>(paddle_tensor.data()),
+        shape,
+        ToPaddleInferPlace(paddle_tensor.place().GetType()));
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Unsupported data type. Now share_external_data only supports INT32, "
-        "INT64, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
+        "INT64, UINT8, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
   }
 }
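
From Python, this unlocks zero-copy feeding of uint8 tensors through the inference API. A hedged sketch of the call path the binding serves (the model paths and single-input layout are hypothetical):

import numpy as np
import paddle
from paddle.inference import Config, create_predictor

config = Config("model.pdmodel", "model.pdiparams")  # hypothetical paths
predictor = create_predictor(config)

# A uint8 batch kept as a paddle.Tensor; with this commit the buffer is
# shared with the predictor instead of hitting the Unimplemented branch.
img = paddle.to_tensor(np.zeros([1, 3, 224, 224], dtype="uint8"))
handle = predictor.get_input_handle(predictor.get_input_names()[0])
handle.share_external_data(img)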

paddle/phi/kernels/fusion/gpu/mmha_util.cu.h (4 additions, 0 deletions)

@@ -1633,6 +1633,7 @@ inline __device__ uint8_t round_tmp(float16 val) {
   return static_cast<uint8_t>(quant_value + 128.0);
 }
 
+#ifdef ENABLE_BF16
 template <>
 inline __device__ uint8_t round_tmp(__nv_bfloat16 val) {
   float quant_value =

@@ -1641,6 +1642,7 @@ inline __device__ uint8_t round_tmp(__nv_bfloat16 val) {
   quant_value = quant_value < -127.0f ? -127.0f : quant_value;
   return static_cast<uint8_t>(quant_value + 128.0);
 }
+#endif
 
 template <>
 inline __device__ uint16_t round_tmp(float2 val) {

@@ -1726,6 +1728,7 @@ inline __device__ uint64_t round_tmp(uint4 val) {
   return ret;
 }
 
+#ifdef ENABLE_BF16
 template <>
 inline __device__ uint16_t round_tmp(__nv_bfloat162 val) {
   union {

@@ -1760,6 +1763,7 @@ inline __device__ uint64_t round_tmp(bf16_8_t val) {
   int16[3] = round_tmp<uint16_t, __nv_bfloat162>(val.w);
   return int64;
 }
+#endif
 
 inline __device__ float2 rotary_embedding_coefficient(const int zid,
                                                       const int rot_embed_dim,
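
The newly guarded specializations only compile when ENABLE_BF16 is defined, which is what fixes the CI build on toolchains without __nv_bfloat16. All round_tmp overloads follow the same symmetric rounding scheme; a NumPy sketch of that scheme (a reference, not the kernel itself):

import numpy as np

def round_tmp_ref(val):
    """Mirror of the round_tmp device overloads: round, clamp the
    quantized value to [-127, 127], then shift into the uint8 range."""
    q = np.rint(np.asarray(val, dtype=np.float32))
    q = np.clip(q, -127.0, 127.0)
    return (q + 128.0).astype(np.uint8)

# 0.0 maps to 128 (the unsigned midpoint); the extremes saturate.
print(round_tmp_ref([-500.0, -127.0, 0.0, 127.0, 500.0]))  # [  1   1 128 255 255]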

paddle/phi/kernels/gpu/c_embedding_kernel.cu (8 additions, 9 deletions)

@@ -41,16 +41,15 @@ __global__ void CEmbedding(T* out,
   CUDA_KERNEL_LOOP(i, limit) {
     size_t row = i / columns;
     size_t col = i % columns;
-    auto id = ids[row];
+    auto id = static_cast<int64_t>(ids[row]);
 
-    PADDLE_ENFORCE(
-        id >= 0 && (vocab_size < 0 || id < vocab_size),
-        "The index is out of bounds, "
-        "please check whether the dimensions of index and "
-        "input meet the requirements. It should "
-        "be less than [%d] and greater than or equal to 0, but received [%d]",
-        vocab_size,
-        id);
+    // PADDLE_ENFORCE(
+    //     id >= 0 && (vocab_size < 0 || id < vocab_size),
+    //     "The index is out of bounds, "
+    //     "please check whether the dimensions of index and "
+    //     "input meet the requirements. It should "
+    //     "be less than [%d] and greater than or equal to 0, but received
+    //     [%d]", vocab_size, id);
     if (id >= start_idx && id < end_idx) {
       auto real_idx = id - start_idx;
       out[i] = table[real_idx * columns + col];
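
With the device-side PADDLE_ENFORCE commented out, ids outside this rank's shard fall through silently instead of trapping the kernel. The surviving gather logic, as a NumPy sketch (illustrative, not the CUDA code):

import numpy as np

def c_embedding_ref(table, ids, start_idx):
    """table holds one rank's vocab shard; only ids inside
    [start_idx, start_idx + shard_rows) are looked up."""
    end_idx = start_idx + table.shape[0]
    out = np.zeros((ids.size, table.shape[1]), dtype=table.dtype)
    for row, idx in enumerate(ids.reshape(-1).astype(np.int64)):
        if start_idx <= idx < end_idx:         # the guard kept in the kernel
            out[row] = table[idx - start_idx]  # real_idx = id - start_idx
    return out.reshape(*ids.shape, table.shape[1])

# A rank owning vocab rows [100, 104): id 102 hits the shard, id 7 does not.
shard = np.arange(8, dtype=np.float32).reshape(4, 2)
print(c_embedding_ref(shard, np.array([[102, 7]]), start_idx=100))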

test/legacy_test/test_block_multihead_attention.py (49 additions, 6 deletions)
@@ -88,6 +88,7 @@ def naive_attention_impl(
     scale=1.0,
     cache_k_dequant_scales=None,
     cache_v_dequant_scales=None,
+    use_cachekv_int8="None",
 ):
     batch = query.shape[0]
     heads = query.shape[1]

@@ -98,13 +99,18 @@ def naive_attention_impl(
         key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
         key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
         key = key.reshape([batch, heads, seq_len, head_dim])
+
+    if use_cachekv_int8 == "dynamic":
+        unsqueeze_shape = [2, 3]
+    elif use_cachekv_int8 == "static":
+        unsqueeze_shape = [0, 2, 3]
     if pre_cache_k is not None:
         key = paddle.concat([pre_cache_k, key], axis=2)
     if cache_k is not None:
         if cache_k_dequant_scales is not None:
             dequant_cache_k = (
                 (cache_k.astype('float32') - 128.0)
-                * cache_k_dequant_scales.unsqueeze([0, 2, 3])
+                * cache_k_dequant_scales.unsqueeze(unsqueeze_shape)
             ).astype(key.dtype)
             key = paddle.concat([dequant_cache_k, key], axis=2)
         else:

@@ -119,7 +125,7 @@ def naive_attention_impl(
         if cache_v_dequant_scales is not None:
             dequant_cache_v = (
                 (cache_v.astype('float32') - 128.0)
-                * cache_v_dequant_scales.unsqueeze([0, 2, 3])
+                * cache_v_dequant_scales.unsqueeze(unsqueeze_shape)
             ).astype(value.dtype)
             value = paddle.concat([dequant_cache_v, value], axis=2)
         else:
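
The two unsqueeze shapes exist because the scale tensors carry different ranks: dynamic cache-KV int8 keeps one scale per (batch, head) pair, while static keeps one per head. A quick shape check (dimensions are illustrative):

import paddle

batch_size, num_head = 2, 8

# dynamic: scales are [batch, head]; unsqueeze([2, 3]) -> [2, 8, 1, 1]
dyn = paddle.ones([batch_size, num_head])
print(dyn.unsqueeze([2, 3]).shape)     # [2, 8, 1, 1]

# static: scales are [head]; unsqueeze([0, 2, 3]) -> [1, 8, 1, 1]
sta = paddle.ones([num_head])
print(sta.unsqueeze([0, 2, 3]).shape)  # [1, 8, 1, 1]

Either result broadcasts against a cache laid out as [batch, head, seq_len, head_dim], which is what the dequant expressions above rely on.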
@@ -1306,6 +1312,13 @@ def test_all(self):
         )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecPTQDequant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()

@@ -1641,6 +1654,13 @@ def test_all(self):
         )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecPTQDequantQuantShiftSmooth(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()

@@ -2013,6 +2033,13 @@ def test_all(self):
         )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecQuant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()

@@ -2282,6 +2309,13 @@ def test_all(self):
         )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecCacheKVDynamicQuant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()
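
get_cuda_version and is_sm_supported are helpers defined earlier in this test file. A sketch of the properties the guard checks, using paddle's device-query API (illustrative only; the file's own helpers remain the source of truth):

import paddle

cuda_ok = paddle.is_compiled_with_cuda()
if cuda_ok:
    major, minor = paddle.device.cuda.get_device_capability()
    sm_ok = major == 8 or (major, minor) == (9, 0)  # "8.x or 90"
    print(cuda_ok, sm_ok)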
@@ -2339,16 +2373,16 @@ def setUp(self):
         self.cache_k = paddle.zeros(shape=self.cache_shape, dtype='uint8')
         self.cache_v = paddle.zeros(shape=self.cache_shape, dtype='uint8')
         self.cache_k_quant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )
         self.cache_v_quant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )
         self.cache_k_dequant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )
         self.cache_v_dequant_scales = paddle.zeros(
-            shape=[self.num_head], dtype='float32'
+            shape=[self.batch_size, self.num_head], dtype='float32'
         )
 
         self.block_tables = paddle.zeros(
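
The scales gain a leading batch dimension because dynamic quantization calibrates per request rather than once per head. A sketch of how per-(batch, head) scales could be derived from a float cache (illustrative; the fused kernel computes these internally):

import paddle

batch_size, num_head, seq_len, head_dim = 2, 8, 16, 64
cache_k_fp = paddle.randn([batch_size, num_head, seq_len, head_dim])

# One symmetric scale per (batch, head): map the observed absmax onto 127.
absmax = cache_k_fp.abs().max(axis=[2, 3])  # shape [batch_size, num_head]
cache_k_quant_scales = 127.0 / absmax
cache_k_dequant_scales = 1.0 / cache_k_quant_scales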
@@ -2510,6 +2544,7 @@ def test_all(self):
                 self.scale,
                 cache_k_dequant_scales=self.cache_k_dequant_scales,
                 cache_v_dequant_scales=self.cache_v_dequant_scales,
+                use_cachekv_int8="dynamic",
             )
             .transpose([0, 2, 1, 3])
             .reshape([self.batch_size, -1])

@@ -2555,6 +2590,13 @@ def test_all(self):
         )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11040
+    or not is_sm_supported,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
+    "and device's compute capability must be 8.x or 90",
+)
 class TestBlockMultiHeadAttnEncDecCacheKVStaticQuant(unittest.TestCase):
     def setUp(self):
         paddle.disable_static()

@@ -2795,6 +2837,7 @@ def test_all(self):
                 self.scale,
                 cache_k_dequant_scales=self.cache_k_dequant_scales,
                 cache_v_dequant_scales=self.cache_v_dequant_scales,
+                use_cachekv_int8="static",
             )
             .transpose([0, 2, 1, 3])
             .reshape([self.batch_size, -1])
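
Putting the pieces together, the uint8 cache round trip the tests compare against follows the same 128-offset convention as round_tmp and the dequant expressions above. A NumPy sketch:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 64)).astype(np.float32)  # one head's cache slice

scale = 127.0 / np.abs(x).max()                       # quant scale
q = (np.clip(np.rint(x * scale), -127, 127) + 128).astype(np.uint8)

x_hat = (q.astype(np.float32) - 128.0) / scale        # dequant, as in the test
print(np.abs(x - x_hat).max())                        # small quantization error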
