Commit 361ff30

Update deep_ep::internode_ll and add barrier function
1 parent 585e852 commit 361ff30

File tree: 7 files changed, +105 -41 lines changed


paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp
Lines changed: 2 additions & 0 deletions

@@ -1625,6 +1625,8 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank,
 #endif
 }
 
+void Buffer::barrier_all() { internode_ll::barrier_all(calc_ctx->stream()); }
+
 #ifdef PADDLE_WITH_NVSHMEM
 std::tuple<deep_ep::detail::Tensor,
            std::optional<deep_ep::detail::Tensor>,

paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp
Lines changed: 1 addition & 0 deletions

@@ -251,6 +251,7 @@ struct Buffer {
   void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank,
                                 int hidden,
                                 int num_experts);
+  void barrier_all();
 
 #ifdef PADDLE_WITH_NVSHMEM
   std::tuple<deep_ep::detail::Tensor,

paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh
Lines changed: 1 addition & 0 deletions

@@ -288,6 +288,7 @@ void combine(cudaDataType_t type,
 // Internode low-latency kernels
 namespace internode_ll {
 
+void barrier_all(cudaStream_t stream);
 void clean_low_latency_buffer(int* clean_0,
                               int num_clean_int_0,
                               int* clean_1,
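
The three hunks above wire a single call chain: Buffer::barrier_all() forwards to internode_ll::barrier_all(calc_ctx->stream()), which launches a kernel that runs a device-side NVSHMEM barrier (see the internode_ll.cu diff below). A minimal standalone sketch of that launch, assuming an initialized NVSHMEM job and substituting a plain kernel launch for the repo's SETUP_LAUNCH_CONFIG/LAUNCH_KERNEL macros:

#include <cuda_runtime.h>
#include <nvshmemx.h>

// One block, one thread: nvshmemx_barrier_all_block() synchronizes this PE
// with every other PE in the job once the block reaches the call.
__global__ void barrier_all_kernel() { nvshmemx_barrier_all_block(); }

// Host wrapper mirroring internode_ll::barrier_all(cudaStream_t): launching
// on `stream` orders the barrier after all work previously queued there.
void barrier_all_sketch(cudaStream_t stream) {
  barrier_all_kernel<<<1, 1, 0, stream>>>();
}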

paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh
Lines changed: 63 additions & 0 deletions

@@ -82,6 +82,12 @@ uint16_t HtoBE16(uint16_t x) {
 
 typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
 
+typedef struct {
+  uint32_t add_data;
+  uint32_t field_boundary;
+  uint64_t reserved;
+} __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t;
+
 __device__ static __forceinline__
 nvshmemi_ibgda_device_state_t* ibgda_get_state() {
   return &nvshmemi_ibgda_device_state_d;
@@ -439,4 +445,61 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
   __syncwarp();
 }
 
+__device__ static __forceinline__ void ibgda_write_amo_add_wqe(
+    nvshmemi_ibgda_device_qp_t *qp, const int &value,
+    uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey,
+    uint16_t wqe_idx, void **out_wqes) {
+  ibgda_ctrl_seg_t ctrl_seg = {0};
+  struct mlx5_wqe_raddr_seg raddr_seg;
+  struct mlx5_wqe_atomic_seg atomic_seg_1;
+  struct mlx5_wqe_data_seg data_seg;
+
+  auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
+  auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(
+      reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
+  auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(
+      reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
+  auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(
+      reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));
+
+  raddr_seg.raddr = HtoBE64(raddr);
+  raddr_seg.rkey = rkey;
+  raddr_seg.reserved = 0;
+
+  // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
+  ctrl_seg.opmod_idx_opcode =
+      HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);
+  auto atomic_32_masked_fa_seg =
+      reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1);
+  atomic_32_masked_fa_seg->add_data = HtoBE32(value);
+  atomic_32_masked_fa_seg->field_boundary = 0;
+
+  ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);
+  ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;
+
+  data_seg.byte_count = HtoBE32(sizeof(int));
+  data_seg.lkey = lkey;
+  data_seg.addr = HtoBE64(laddr);
+
+  EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization");
+  st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<int4*>(&ctrl_seg));
+  st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<int4*>(&raddr_seg));
+  st_na_relaxed(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<int4*>(&atomic_seg_1));
+  st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
+}
+
+__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(
+    void *rptr, const int& value, int pe, int qp_id) {
+  nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);
+
+  __be32 rkey;
+  uint64_t raddr;
+  ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &raddr, &rkey);
+
+  uint64_t my_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
+  void *wqe_ptrs = ibgda_get_wqe_ptr(qp, my_wqe_idx);
+
+  ibgda_write_amo_add_wqe(qp, value, reinterpret_cast<uint64_t>(qp->ibuf.buf),
+                          qp->ibuf.lkey, raddr, rkey, my_wqe_idx, &wqe_ptrs);
+
+  ibgda_submit_requests<true>(qp, my_wqe_idx, 1);
+}
+
 } // namespace deep_ep
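
The new WQE writer lays four 16-byte segments into the queue slot: control, remote address, atomic argument, and local data segment. The EP_STATIC_ASSERTs pin each to sizeof(int4), which is also why qpn_ds ends in `| 4` (a DS count of four 16-byte units). A host-side sketch of the two control-word packings, assuming MLX5_OPCODE_ATOMIC_MASKED_FA carries the upstream mlx5 value 0x15 (the 0x08000000 opmod comes from the diff's own comment):

#include <cstdint>
#include <cstdio>

// Assumption: 0x15 is MLX5_OPCODE_ATOMIC_MASKED_FA in the mlx5 headers.
static uint32_t pack_opmod_idx_opcode(uint16_t wqe_idx) {
  const uint32_t kOpcodeMaskedFetchAdd = 0x15;    // assumed opcode value
  const uint32_t k4ByteExtAmoOpmod = 0x08000000;  // from the diff's comment
  return kOpcodeMaskedFetchAdd | (uint32_t{wqe_idx} << 8) | k4ByteExtAmoOpmod;
}

// DS count = 4: ctrl + raddr + atomic + data, 16 bytes each = one 64B WQE.
static uint32_t pack_qpn_ds(uint32_t qpn) { return (qpn << 8) | 4; }

int main() {
  printf("opmod_idx_opcode = 0x%08x\n", pack_opmod_idx_opcode(7));  // 0x08000715
  printf("qpn_ds           = 0x%08x\n", pack_qpn_ds(0x1234));       // 0x00123404
}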

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu
Lines changed: 34 additions & 41 deletions

@@ -34,6 +34,15 @@ namespace deep_ep {
 
 namespace internode_ll {
 
+__global__ void barrier_all() { nvshmemx_barrier_all_block(); }
+
+void barrier_all(cudaStream_t stream) {
+  constexpr int kNumThreads = 1;
+
+  SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
+  LAUNCH_KERNEL(&cfg, barrier_all);
+}
+
 template <int kNumThreads>
 __launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(
     int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1) {
@@ -112,7 +121,6 @@ __global__ __launch_bounds__(
 
   // Message package: hidden data, FP8 scales, index at source
   // NOTES: currently we have 3 reserved int fields for future use
-
   using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
   const size_t num_bytes_per_msg =
       sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float))
@@ -305,13 +313,11 @@ __global__ __launch_bounds__(
                 responsible_expert_idx) != FINISHED_SUM_TAG * 2) {
       }
       if (dst_rank != rank) {
-        nvshmemi_ibgda_rma_p(
+        nvshmemi_ibgda_amo_nonfetch_add(
             rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
             -num_tokens_sent - 1,
             dst_rank,
-            dst_expert_local_idx,
-            0);
-        nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
+            dst_expert_local_idx);
       } else {
         st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
                       -num_tokens_sent - 1);
@@ -366,16 +372,10 @@ LOW_LATENCY_DISPATCH_RECV:
     EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
                      "Requires more than one warp per group");
     if (sub_warp_id == 1 && lane_id == 0) {
-      if (src_rank != rank) {
-        nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
-        num_recv_tokens = ld_acquire_sys_global(
-            rdma_recv_count + local_expert_idx * num_ranks + src_rank);
-        EP_DEVICE_ASSERT(num_recv_tokens != 0);
-      } else {
-        while ((num_recv_tokens = ld_acquire_global(
-                    rdma_recv_count + local_expert_idx * num_ranks +
-                    src_rank)) == 0) {
-        }
+      // printf("enter recv meta\n");
+      while ((num_recv_tokens = ld_acquire_global(
+                  rdma_recv_count + local_expert_idx * num_ranks + src_rank)) ==
+             0) {
       }
       num_recv_tokens = -num_recv_tokens - 1;
       recv_token_begin_idx =
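
These two hunks change the dispatch completion protocol: the sender now bumps the receiver's rdma_recv_count slot with an IBGDA atomic add instead of an rma_p write paired with prepare_recvs/poll_recv CQ polling, so local and remote sources are drained by the same acquire-load spin. A self-contained CUDA sketch of that wait, with a volatile load plus fence standing in for the repo's ld_acquire_global (an assumption about its semantics):

#include <cuda_runtime.h>

// Assumption: ld_acquire_global is an acquire load from global memory; a
// volatile load followed by a fence is used here as a stand-in.
__device__ __forceinline__ int acquire_load_sketch(const int* ptr) {
  int v = *reinterpret_cast<const volatile int*>(ptr);
  __threadfence();  // keep later reads ordered after the flag observation
  return v;
}

// Spin until the counter slot (bumped remotely by
// nvshmemi_ibgda_amo_nonfetch_add, or locally by st_na_release) is nonzero.
__device__ int wait_recv_count(const int* slot) {
  int v;
  while ((v = acquire_load_sketch(slot)) == 0) {
  }
  return -v - 1;  // the sender encodes the token count as -(n + 1)
}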
@@ -539,7 +539,8 @@ __global__ __launch_bounds__(
     int num_experts,
     int rank,
     int num_ranks,
-    int phases) {
+    int phases,
+    bool zero_copy) {
   const auto sm_id = static_cast<int>(blockIdx.x);
   const auto num_sms = static_cast<int>(gridDim.x);
   const auto thread_id = static_cast<int>(threadIdx.x);
@@ -580,7 +581,9 @@ __global__ __launch_bounds__(
     const auto local_expert_idx = responsible_expert_idx % num_local_experts;
     const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
     const auto layout =
-        __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
+        __ldg(layout_range + local_expert_idx * num_ranks +
+              dst_rank);  // num_recv_tokens, recv_token_begin_idx
+
     const auto local_x = reinterpret_cast<const int4*>(x) +
                          local_expert_idx * num_ranks *
                              num_max_dispatch_tokens_per_rank *
@@ -625,13 +628,14 @@ __global__ __launch_bounds__(
                            st_na_global);
       } else {
         const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
-        UNROLLED_WARP_COPY(7,
-                           lane_id,
-                           hidden_bf16_int4,
-                           buf_int4_ptr,
-                           x_int4,
-                           ld_nc_global,
-                           st_na_global);
+        if (!zero_copy)
+          UNROLLED_WARP_COPY(7,
+                             lane_id,
+                             hidden_bf16_int4,
+                             buf_int4_ptr,
+                             x_int4,
+                             ld_nc_global,
+                             st_na_global);
         nvshmemi_ibgda_put_nbi_warp(dst_ptr,
                                     buf_ptr,
                                     hidden * sizeof(nv_bfloat16),
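
The new zero_copy flag guards only the staging copy: when set, the payload is assumed to already reside in the RDMA send buffer, and the kernel goes straight to the network put, which always reads from that buffer. A minimal CUDA sketch of the branch (names hypothetical, the put is stubbed):

#include <cuda_runtime.h>

__device__ void rdma_put_stub(const int4* /*buf*/, int /*n*/) {
  // Stand-in for nvshmemi_ibgda_put_nbi_warp; always sources from rdma_buf.
}

__global__ void send_sketch(int4* rdma_buf, const int4* x, int n,
                            bool zero_copy) {
  if (!zero_copy) {
    // Staging copy striped across the block (UNROLLED_WARP_COPY analogue).
    for (int i = threadIdx.x; i < n; i += blockDim.x) rdma_buf[i] = x[i];
  }
  rdma_put_stub(rdma_buf, n);
}

Note that the host-side combine launcher (last hunk of this file) currently hard-wires the flag to false, so the staging copy stays on by default.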
@@ -651,11 +655,8 @@ __global__ __launch_bounds__(
       while (ld_acquire_global(atomic_clean_flag) == 0) {
       }
       if (dst_rank != rank) {
-        nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx,
-                             1,
-                             dst_rank,
-                             local_expert_idx,
-                             0);
+        nvshmemi_ibgda_amo_nonfetch_add(
+            rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
       } else {
         st_na_release(rdma_recv_flag + global_expert_idx, 1);
       }
@@ -672,18 +673,9 @@ LOW_LATENCY_COMBINE_RECV:
   if (responsible_expert_idx < num_experts) {
     EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
                      "Invalid number of warps per group");
-    if (sub_warp_id == 0 && lane_id == 0) {
-      // TODO(Xreki): refactor QP indices
-      auto src_rank = responsible_expert_idx / num_local_experts;
-      auto src_expert_idx = responsible_expert_idx % num_local_experts;
-      if (src_rank != rank) {
-        nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
-      } else {
-        while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) ==
-               0) {
-        }
+    if (sub_warp_id == 0 && lane_id == 0)
+      while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0) {
       }
-    }
   }
   cg::this_grid().sync();
 
@@ -796,7 +788,8 @@ void combine(void* combined_x,
                      num_experts, \
                      rank, \
                      num_ranks, \
-                     phases); \
+                     phases, \
+                     false); \
   } \
   break

paddle/fluid/pybind/deep_ep_api.cc
Lines changed: 1 addition & 0 deletions

@@ -89,6 +89,7 @@ void BindDeepEPApi(pybind11::module *m) {
       .def("intranode_combine", &deep_ep::Buffer::intranode_combine_api)
       .def("internode_dispatch", &deep_ep::Buffer::internode_dispatch_api)
       .def("internode_combine", &deep_ep::Buffer::internode_combine_api)
+      .def("barrier_all", &deep_ep::Buffer::barrier_all)
       .def("clean_low_latency_buffer",
            &deep_ep::Buffer::clean_low_latency_buffer)
       .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch_api)

python/paddle/distributed/communication/deep_ep/buffer.py
Lines changed: 3 additions & 0 deletions

@@ -736,6 +736,9 @@ def internode_combine(
         )
         return combined_x, combined_topk_weights, EventOverlap(event)
 
+    def barrier_all(self):
+        self.runtime.barrier_all()
+
     def clean_low_latency_buffer(
         self,
         num_max_dispatch_tokens_per_rank: int,
