Skip to content

Commit 156602d

Browse files
authored
【Inference】Add support for deep_ep low_latency_dispatch BF16 (PaddlePaddle#72036)
【Inference】Add support for deep_ep low_latency_dispatch BF16 (PaddlePaddle#72036)
1 parent 9110209 commit 156602d

File tree

6 files changed

+436
-212
lines changed

6 files changed

+436
-212
lines changed

paddle/fluid/distributed/collective/deep_ep/config.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ struct LowLatencyLayout {
189189
// Message sizes
190190
EP_HOST_ASSERT(num_scales * static_cast<int64_t>(sizeof(float)) <= hidden);
191191
size_t num_bytes_per_dispatch_msg =
192-
hidden + num_scales * sizeof(float) + sizeof(int4);
192+
sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16),
193+
hidden + num_scales * sizeof(float));
194+
193195
size_t num_bytes_per_combine_msg =
194196
sizeof(int4) + hidden * sizeof(nv_bfloat16);
195197

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,7 +1627,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank,
16271627

16281628
#ifdef PADDLE_WITH_NVSHMEM
16291629
std::tuple<deep_ep::detail::Tensor,
1630-
deep_ep::detail::Tensor,
1630+
std::optional<deep_ep::detail::Tensor>,
16311631
deep_ep::detail::Tensor,
16321632
deep_ep::detail::Tensor,
16331633
deep_ep::detail::Tensor,
@@ -1637,6 +1637,7 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
16371637
const deep_ep::detail::Tensor& topk_idx,
16381638
int num_max_dispatch_tokens_per_rank,
16391639
int num_experts,
1640+
bool use_fp8,
16401641
bool async,
16411642
bool return_recv_hook) {
16421643
EP_HOST_ASSERT(low_latency_mode);
@@ -1675,12 +1676,13 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
16751676
if (!return_recv_hook) stream_wait(launch_stream, compute_stream);
16761677

16771678
// Allocate packed tensors
1678-
auto packed_recv_x = ConvertPaddleTensorToDetailTensor(
1679-
paddle::experimental::empty({num_local_experts,
1680-
num_ranks * num_max_dispatch_tokens_per_rank,
1681-
hidden},
1682-
phi::DataType::FLOAT8_E4M3FN,
1683-
x.place()));
1679+
auto packed_recv_x =
1680+
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
1681+
{num_local_experts,
1682+
num_ranks * num_max_dispatch_tokens_per_rank,
1683+
hidden},
1684+
use_fp8 ? phi::DataType::FLOAT8_E4M3FN : phi::DataType::BFLOAT16,
1685+
x.place()));
16841686
auto packed_recv_src_info =
16851687
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
16861688
{num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
@@ -1695,25 +1697,32 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
16951697
{num_local_experts}, phi::DataType::INT32, phi::GPUPlace(device_id)));
16961698

16971699
// Allocate column-majored scales
1698-
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 &&
1699-
"TMA requires the number of tokens to be multiple of 4");
1700-
auto packed_recv_x_scales =
1701-
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
1702-
{num_local_experts,
1703-
num_scales,
1704-
num_ranks * num_max_dispatch_tokens_per_rank},
1705-
phi::DataType::FLOAT32,
1706-
phi::GPUPlace(device_id)));
1707-
packed_recv_x_scales =
1708-
ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose(
1709-
ConvertDetailTensorToPaddleTensor(packed_recv_x_scales),
1710-
std::vector<int>{1, 2}));
1700+
auto packed_recv_x_scales = std::optional<deep_ep::detail::Tensor>();
1701+
1702+
float* packed_recv_x_scales_ptr = nullptr;
1703+
1704+
if (use_fp8) {
1705+
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 &&
1706+
"TMA requires the number of tokens to be multiple of 4");
1707+
packed_recv_x_scales =
1708+
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
1709+
{num_local_experts,
1710+
num_scales,
1711+
num_ranks * num_max_dispatch_tokens_per_rank},
1712+
phi::DataType::FLOAT32,
1713+
phi::GPUPlace(device_id)));
1714+
packed_recv_x_scales =
1715+
ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose(
1716+
ConvertDetailTensorToPaddleTensor(packed_recv_x_scales.value()),
1717+
std::vector<int>{0, 2, 1}));
1718+
packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr<float>();
1719+
}
17111720

17121721
// Kernel launch
17131722
auto next_clean_meta = next_buffer.clean_meta();
17141723
auto launcher = [=](int phases) {
17151724
internode_ll::dispatch(packed_recv_x.data_ptr(),
1716-
packed_recv_x_scales.data_ptr<float>(),
1725+
packed_recv_x_scales_ptr,
17171726
packed_recv_src_info.data_ptr<int>(),
17181727
packed_recv_layout_range.data_ptr<int64_t>(),
17191728
packed_recv_count.data_ptr<int>(),
@@ -1731,6 +1740,7 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
17311740
num_experts,
17321741
rank,
17331742
num_ranks,
1743+
use_fp8,
17341744
workspace,
17351745
launch_stream,
17361746
phases);
@@ -2092,7 +2102,7 @@ Buffer::internode_combine_api(
20922102
}
20932103

20942104
std::tuple<paddle::Tensor,
2095-
paddle::Tensor,
2105+
std::optional<paddle::Tensor>,
20962106
paddle::Tensor,
20972107
paddle::Tensor,
20982108
paddle::Tensor,
@@ -2102,6 +2112,7 @@ Buffer::low_latency_dispatch_api(const paddle::Tensor& x,
21022112
const paddle::Tensor& topk_idx,
21032113
int num_max_dispatch_tokens_per_rank,
21042114
int num_experts,
2115+
bool use_fp8,
21052116
bool async,
21062117
bool return_recv_hook) {
21072118
#ifdef PADDLE_WITH_NVSHMEM
@@ -2112,12 +2123,18 @@ Buffer::low_latency_dispatch_api(const paddle::Tensor& x,
21122123
topk_idx_,
21132124
num_max_dispatch_tokens_per_rank,
21142125
num_experts,
2126+
use_fp8,
21152127
async,
21162128
return_recv_hook);
21172129

21182130
auto packed_recv_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));
2119-
auto packed_recv_x_scales_ =
2120-
ConvertDetailTensorToPaddleTensor(std::get<1>(res));
2131+
2132+
std::optional<paddle::Tensor> packed_recv_x_scales_;
2133+
if (std::get<1>(res).has_value()) {
2134+
packed_recv_x_scales_ =
2135+
ConvertDetailTensorToPaddleTensor(std::get<1>(res).value());
2136+
}
2137+
21212138
auto packed_recv_count_ = ConvertDetailTensorToPaddleTensor(std::get<2>(res));
21222139
auto packed_recv_src_info_ =
21232140
ConvertDetailTensorToPaddleTensor(std::get<3>(res));

paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ struct Buffer {
254254

255255
#ifdef PADDLE_WITH_NVSHMEM
256256
std::tuple<deep_ep::detail::Tensor,
257-
deep_ep::detail::Tensor,
257+
std::optional<deep_ep::detail::Tensor>,
258258
deep_ep::detail::Tensor,
259259
deep_ep::detail::Tensor,
260260
deep_ep::detail::Tensor,
@@ -264,6 +264,7 @@ struct Buffer {
264264
const deep_ep::detail::Tensor& topk_idx,
265265
int num_max_dispatch_tokens_per_rank,
266266
int num_experts,
267+
bool use_fp8,
267268
bool async,
268269
bool return_recv_hook);
269270

@@ -335,7 +336,7 @@ struct Buffer {
335336
bool allocate_on_comm_stream);
336337

337338
std::tuple<paddle::Tensor,
338-
paddle::Tensor,
339+
std::optional<paddle::Tensor>,
339340
paddle::Tensor,
340341
paddle::Tensor,
341342
paddle::Tensor,
@@ -345,6 +346,7 @@ struct Buffer {
345346
const paddle::Tensor& topk_idx,
346347
int num_max_dispatch_tokens_per_rank,
347348
int num_experts,
349+
bool use_fp8,
348350
bool async,
349351
bool return_recv_hook);
350352

0 commit comments

Comments (0)