Skip to content

No.2 to Support DEEP_EP_BF16 & Fix BUG #72065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into the base branch on Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -688,10 +688,10 @@ LOW_LATENCY_COMBINE_RECV:
cg::this_grid().sync();

// Reduce tokens with FP8 cast
EP_DEVICE_ASSERT(num_topk <= 32 && hidden_bf16_int4 <= num_threads);
// EP_DEVICE_ASSERT(num_topk <= 32 && hidden_bf16_int4 <= num_threads);
EP_STATIC_ASSERT(kHidden % (32 * kNumElemsPerInt4) == 0,
"Invalid vectorization");
if (thread_id < hidden_bf16_int4) {
for (int g_id = thread_id; g_id < hidden_bf16_int4; g_id += num_threads) {
for (int token_idx = sm_id; token_idx < num_combined_tokens;
token_idx += num_sms) {
// Read top-k indices and weights
Expand All @@ -718,7 +718,7 @@ LOW_LATENCY_COMBINE_RECV:

// Reduce
auto x_vec = ld_nc_global(
reinterpret_cast<const int4*>(rdma_buffer_row) + thread_id);
reinterpret_cast<const int4*>(rdma_buffer_row) + g_id);
const auto x_bf16 = reinterpret_cast<nv_bfloat16*>(&x_vec);
#pragma unroll
for (int j = 0; j < kNumElemsPerInt4; ++j)
Expand All @@ -733,7 +733,7 @@ LOW_LATENCY_COMBINE_RECV:
for (int j = 0; j < kNumElemsPerInt4; ++j)
combined_bf16[j] = static_cast<nv_bfloat16>(combined_values[j]);
(reinterpret_cast<int4*>(combined_x) +
token_idx * hidden_bf16_int4)[thread_id] = combined_int4;
token_idx * hidden_bf16_int4)[g_id] = combined_int4;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@

#define SWITCH_HIDDEN(case_macro) \
switch (hidden) { \
case 2048: case_macro(2048); \
case 2560: case_macro(2560); \
case 5120: case_macro(5120); \
case 7168: case_macro(7168); \
case 8192: case_macro(8192); \
default: EP_HOST_ASSERT(false && "Unsupported hidden"); \
} while (false)
1 change: 1 addition & 0 deletions paddle/fluid/pybind/deep_ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <Python.h>
#include "pybind11/functional.h"
#include "pybind11/stl.h"

#ifdef PADDLE_WITH_DEEP_EP
Expand Down
Loading