Skip to content

Commit e16bd30

Browse files
committed
Make deep_ep's ll_internode use NVLink for intranode communication
1 parent 21a4207 commit e16bd30

File tree

2 files changed

+89
-22
lines changed
  • paddle/fluid/distributed/collective/deep_ep/kernels
  • python/paddle/distributed/communication/deep_ep

2 files changed

+89
-22
lines changed

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -225,13 +225,33 @@ __global__ __launch_bounds__(
225225
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
226226
slot_idx * num_bytes_per_msg;
227227
if (dst_rank != rank) {
228-
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
229-
src_ptr,
230-
num_bytes_per_msg,
231-
dst_rank,
232-
dst_expert_local_idx,
233-
lane_id,
234-
slot_idx);
228+
void* peer_base_addr = reinterpret_cast<void*>(
229+
__ldg(reinterpret_cast<const uint64_t*>(
230+
nvshmemi_device_state_d.peer_heap_base_p2p) +
231+
dst_rank));
232+
if (peer_base_addr) {
233+
char* req_rptr_actual =
234+
reinterpret_cast<char*>(peer_base_addr) +
235+
(reinterpret_cast<char*>(dst_ptr) -
236+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base));
237+
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
238+
const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
239+
UNROLLED_WARP_COPY(8,
240+
lane_id,
241+
num_int4_per_msg,
242+
dst_int4_ptr,
243+
src_int4_ptr,
244+
ld_nc_global,
245+
st_na_global);
246+
} else {
247+
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
248+
src_ptr,
249+
num_bytes_per_msg,
250+
dst_rank,
251+
dst_expert_local_idx,
252+
lane_id,
253+
slot_idx);
254+
}
235255
} else {
236256
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
237257
const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
@@ -313,11 +333,24 @@ __global__ __launch_bounds__(
313333
responsible_expert_idx) != FINISHED_SUM_TAG * 2) {
314334
}
315335
if (dst_rank != rank) {
316-
nvshmemi_ibgda_amo_nonfetch_add(
317-
rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
318-
-num_tokens_sent - 1,
319-
dst_rank,
320-
dst_expert_local_idx);
336+
void* peer_base_addr = reinterpret_cast<void*>(
337+
__ldg(reinterpret_cast<const uint64_t*>(
338+
nvshmemi_device_state_d.peer_heap_base_p2p) +
339+
dst_rank));
340+
if (peer_base_addr) { // P2P enabled
341+
int* rptr_actual = reinterpret_cast<int*>(
342+
reinterpret_cast<char*>(peer_base_addr) +
343+
(reinterpret_cast<char*>(rdma_recv_count +
344+
dst_expert_local_idx * num_ranks + rank) -
345+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base)));
346+
st_na_release(rptr_actual, -num_tokens_sent - 1);
347+
} else {
348+
nvshmemi_ibgda_amo_nonfetch_add(
349+
rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
350+
-num_tokens_sent - 1,
351+
dst_rank,
352+
dst_expert_local_idx);
353+
}
321354
} else {
322355
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
323356
-num_tokens_sent - 1);
@@ -635,13 +668,32 @@ __global__ __launch_bounds__(
635668
x_int4,
636669
ld_nc_global,
637670
st_na_global);
638-
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
639-
buf_ptr,
640-
hidden * sizeof(nv_bfloat16),
641-
dst_rank,
642-
local_expert_idx,
643-
lane_id,
644-
token_idx - offset);
671+
void* peer_base_addr = reinterpret_cast<void*>(
672+
__ldg(reinterpret_cast<const uint64_t*>(
673+
nvshmemi_device_state_d.peer_heap_base_p2p) +
674+
dst_rank));
675+
if (peer_base_addr) {
676+
char* req_rptr_actual =
677+
reinterpret_cast<char*>(peer_base_addr) +
678+
(reinterpret_cast<char*>(dst_ptr) -
679+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base));
680+
const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
681+
UNROLLED_WARP_COPY(7,
682+
lane_id,
683+
hidden_bf16_int4,
684+
dst_int4_ptr,
685+
x_int4,
686+
ld_nc_global,
687+
st_na_global);
688+
} else {
689+
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
690+
buf_ptr,
691+
hidden * sizeof(nv_bfloat16),
692+
dst_rank,
693+
local_expert_idx,
694+
lane_id,
695+
token_idx - offset);
696+
}
645697
}
646698
}
647699

@@ -654,8 +706,22 @@ __global__ __launch_bounds__(
654706
while (ld_acquire_global(atomic_clean_flag) == 0) {
655707
}
656708
if (dst_rank != rank) {
657-
nvshmemi_ibgda_amo_nonfetch_add(
658-
rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
709+
void* peer_base_addr = reinterpret_cast<void*>(
710+
__ldg(reinterpret_cast<const uint64_t*>(
711+
nvshmemi_device_state_d.peer_heap_base_p2p) +
712+
dst_rank));
713+
if (peer_base_addr) {
714+
int* req_rptr_actual = reinterpret_cast<int*>(
715+
reinterpret_cast<char*>(peer_base_addr) +
716+
(reinterpret_cast<char*>(rdma_recv_flag + global_expert_idx) -
717+
reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base)));
718+
st_na_release(req_rptr_actual, 1);
719+
} else {
720+
nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx,
721+
1,
722+
dst_rank,
723+
local_expert_idx);
724+
}
659725
} else {
660726
st_na_release(rdma_recv_flag + global_expert_idx, 1);
661727
}

python/paddle/distributed/communication/deep_ep/buffer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ def __init__(
108108
# Enable IBGDA for the low latency mode, which refers to "no packet forwarding between NVLink and RDMA"
109109
if low_latency_mode:
110110
assert num_qps_per_rank > 0
111-
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
111+
if not os.getenv("NVSHMEM_DISABLE_P2P"):
112+
os.environ['NVSHMEM_DISABLE_P2P'] = '1'
112113
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
113114
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
114115
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = (

0 commit comments

Comments
 (0)