【Inference】Add support for deep_ep low_latency_dispatch BF16 #72036

Merged · 2 commits · Apr 3, 2025
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/collective/deep_ep/config.hpp
@@ -189,7 +189,9 @@ struct LowLatencyLayout {
// Message sizes
EP_HOST_ASSERT(num_scales * static_cast<int64_t>(sizeof(float)) <= hidden);
size_t num_bytes_per_dispatch_msg =
hidden + num_scales * sizeof(float) + sizeof(int4);
sizeof(int4) + std::max(hidden * sizeof(nv_bfloat16),
hidden + num_scales * sizeof(float));

size_t num_bytes_per_combine_msg =
sizeof(int4) + hidden * sizeof(nv_bfloat16);

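The new dispatch message size reserves room for whichever payload is larger: the raw BF16 hidden vector (2 bytes per element, no scales) or the FP8 hidden vector plus its per-group float scales, so one buffer layout serves both modes. A minimal standalone sketch of the arithmetic, assuming hidden = 7168 and one scale per 128 elements (illustrative values, not taken from this PR):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t hidden = 7168;              // assumed hidden size, in elements
  const std::size_t num_scales = hidden / 128;  // assumed FP8 scaling granularity
  const std::size_t fp8_bytes = hidden + num_scales * sizeof(float);  // 1 B/elem + float scales
  const std::size_t bf16_bytes = hidden * 2;                          // sizeof(nv_bfloat16) == 2
  const std::size_t msg_bytes = 16 /* sizeof(int4) header */ + std::max(bf16_bytes, fp8_bytes);
  std::printf("fp8 = %zu B, bf16 = %zu B, per-message = %zu B\n", fp8_bytes, bf16_bytes, msg_bytes);
  return 0;
}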
65 changes: 41 additions & 24 deletions paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
@@ -1627,7 +1627,7 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank,

#ifdef PADDLE_WITH_NVSHMEM
std::tuple<deep_ep::detail::Tensor,
deep_ep::detail::Tensor,
std::optional<deep_ep::detail::Tensor>,
deep_ep::detail::Tensor,
deep_ep::detail::Tensor,
deep_ep::detail::Tensor,
@@ -1637,6 +1637,7 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
const deep_ep::detail::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_fp8,
bool async,
bool return_recv_hook) {
EP_HOST_ASSERT(low_latency_mode);
@@ -1675,12 +1676,13 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
if (!return_recv_hook) stream_wait(launch_stream, compute_stream);

// Allocate packed tensors
auto packed_recv_x = ConvertPaddleTensorToDetailTensor(
paddle::experimental::empty({num_local_experts,
num_ranks * num_max_dispatch_tokens_per_rank,
hidden},
phi::DataType::FLOAT8_E4M3FN,
x.place()));
auto packed_recv_x =
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
{num_local_experts,
num_ranks * num_max_dispatch_tokens_per_rank,
hidden},
use_fp8 ? phi::DataType::FLOAT8_E4M3FN : phi::DataType::BFLOAT16,
x.place()));
auto packed_recv_src_info =
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
{num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank},
@@ -1695,25 +1697,32 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
{num_local_experts}, phi::DataType::INT32, phi::GPUPlace(device_id)));

// Allocate column-majored scales
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 &&
"TMA requires the number of tokens to be multiple of 4");
auto packed_recv_x_scales =
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
{num_local_experts,
num_scales,
num_ranks * num_max_dispatch_tokens_per_rank},
phi::DataType::FLOAT32,
phi::GPUPlace(device_id)));
packed_recv_x_scales =
ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose(
ConvertDetailTensorToPaddleTensor(packed_recv_x_scales),
std::vector<int>{1, 2}));
auto packed_recv_x_scales = std::optional<deep_ep::detail::Tensor>();

float* packed_recv_x_scales_ptr = nullptr;

if (use_fp8) {
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 &&
"TMA requires the number of tokens to be multiple of 4");
packed_recv_x_scales =
ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
{num_local_experts,
num_scales,
num_ranks * num_max_dispatch_tokens_per_rank},
phi::DataType::FLOAT32,
phi::GPUPlace(device_id)));
packed_recv_x_scales =
ConvertPaddleTensorToDetailTensor(paddle::experimental::transpose(
ConvertDetailTensorToPaddleTensor(packed_recv_x_scales.value()),
std::vector<int>{0, 2, 1}));
packed_recv_x_scales_ptr = packed_recv_x_scales.value().data_ptr<float>();
}

// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
internode_ll::dispatch(packed_recv_x.data_ptr(),
packed_recv_x_scales.data_ptr<float>(),
packed_recv_x_scales_ptr,
packed_recv_src_info.data_ptr<int>(),
packed_recv_layout_range.data_ptr<int64_t>(),
packed_recv_count.data_ptr<int>(),
@@ -1731,6 +1740,7 @@ Buffer::low_latency_dispatch(const deep_ep::detail::Tensor& x,
num_experts,
rank,
num_ranks,
use_fp8,
workspace,
launch_stream,
phases);
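When use_fp8 is false, packed_recv_x_scales_ptr stays nullptr, so the kernel launched above receives a null scales pointer together with the new use_fp8 flag. The actual handling lives in internode_ll's dispatch kernel, which this diff does not show; the following is only a hypothetical sketch of the contract those two arguments establish:

#include <cassert>

// Hypothetical illustration, not the PR's kernel: only use_fp8 and the
// null-vs-non-null scales pointer are taken from the code above.
void illustrate_dispatch_contract(bool use_fp8, float* packed_recv_x_scales_ptr) {
  if (use_fp8) {
    // FP8 path: tokens are received as FP8 bytes and one float scale per group
    // is written through the (non-null) scales pointer.
    assert(packed_recv_x_scales_ptr != nullptr);
  } else {
    // BF16 path: tokens are received as nv_bfloat16; no scales are produced,
    // so the pointer stays null and must never be dereferenced.
    assert(packed_recv_x_scales_ptr == nullptr);
  }
}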
@@ -2092,7 +2102,7 @@ Buffer::internode_combine_api(
}

std::tuple<paddle::Tensor,
paddle::Tensor,
std::optional<paddle::Tensor>,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
@@ -2102,6 +2112,7 @@ Buffer::low_latency_dispatch_api(const paddle::Tensor& x,
const paddle::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_fp8,
bool async,
bool return_recv_hook) {
#ifdef PADDLE_WITH_NVSHMEM
@@ -2112,12 +2123,18 @@ Buffer::low_latency_dispatch_api(const paddle::Tensor& x,
topk_idx_,
num_max_dispatch_tokens_per_rank,
num_experts,
use_fp8,
async,
return_recv_hook);

auto packed_recv_x_ = ConvertDetailTensorToPaddleTensor(std::get<0>(res));
auto packed_recv_x_scales_ =
ConvertDetailTensorToPaddleTensor(std::get<1>(res));

std::optional<paddle::Tensor> packed_recv_x_scales_;
if (std::get<1>(res).has_value()) {
packed_recv_x_scales_ =
ConvertDetailTensorToPaddleTensor(std::get<1>(res).value());
}

auto packed_recv_count_ = ConvertDetailTensorToPaddleTensor(std::get<2>(res));
auto packed_recv_src_info_ =
ConvertDetailTensorToPaddleTensor(std::get<3>(res));
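Because the scales tensor is now optional, callers of low_latency_dispatch_api must check for it before use. A hedged caller-side sketch: the call follows the signature in this diff, while buffer, x, topk_idx, and the surrounding scalars are assumed to exist and are purely illustrative.

// Illustrative fragment only; return-tuple indices follow the declarations in this PR.
auto res = buffer.low_latency_dispatch_api(x, topk_idx,
                                           num_max_dispatch_tokens_per_rank,
                                           num_experts,
                                           /*use_fp8=*/false,   // BF16 dispatch
                                           /*async=*/false,
                                           /*return_recv_hook=*/false);
const paddle::Tensor& recv_x = std::get<0>(res);                 // BF16 when use_fp8 == false
const std::optional<paddle::Tensor>& recv_scales = std::get<1>(res);
if (recv_scales.has_value()) {
  // FP8 path: dequantize recv_x with the column-majored float scales.
} else {
  // BF16 path: recv_x is already full precision and can be used directly.
}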
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp
@@ -254,7 +254,7 @@ struct Buffer {

#ifdef PADDLE_WITH_NVSHMEM
std::tuple<deep_ep::detail::Tensor,
deep_ep::detail::Tensor,
std::optional<deep_ep::detail::Tensor>,
deep_ep::detail::Tensor,
deep_ep::detail::Tensor,
deep_ep::detail::Tensor,
@@ -264,6 +264,7 @@ struct Buffer {
const deep_ep::detail::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_fp8,
bool async,
bool return_recv_hook);

@@ -335,7 +336,7 @@ struct Buffer {
bool allocate_on_comm_stream);

std::tuple<paddle::Tensor,
paddle::Tensor,
std::optional<paddle::Tensor>,
paddle::Tensor,
paddle::Tensor,
paddle::Tensor,
@@ -345,6 +346,7 @@ struct Buffer {
const paddle::Tensor& topk_idx,
int num_max_dispatch_tokens_per_rank,
int num_experts,
bool use_fp8,
bool async,
bool return_recv_hook);
