@@ -34,6 +34,15 @@ namespace deep_ep {
 
 namespace internode_ll {
 
+__global__ void barrier_all() { nvshmemx_barrier_all_block(); }
+
+void barrier_all(cudaStream_t stream) {
+  constexpr int kNumThreads = 1;
+
+  SETUP_LAUNCH_CONFIG(1, kNumThreads, stream);
+  LAUNCH_KERNEL(&cfg, barrier_all);
+}
+
 template <int kNumThreads>
 __launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(
     int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1) {
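The new host-side barrier_all(stream) wrapper enqueues a single-thread kernel whose only work is the block-scoped NVSHMEM collective nvshmemx_barrier_all_block(), giving callers a stream-ordered barrier across all ranks. A minimal usage sketch, assuming an initialized NVSHMEM job and that the barrier_all declaration from this file is visible; the calling code here is hypothetical, not part of the patch:

    #include <cuda_runtime.h>

    // Hypothetical caller: order all ranks on `stream` before reusing buffers.
    void sync_all_ranks() {
      cudaStream_t stream;
      cudaStreamCreate(&stream);
      deep_ep::internode_ll::barrier_all(stream);  // enqueue device-wide barrier
      // The barrier kernel completes only after every rank arrives, so this
      // host wait implies all ranks have passed the barrier.
      cudaStreamSynchronize(stream);
      cudaStreamDestroy(stream);
    }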
@@ -112,7 +121,6 @@ __global__ __launch_bounds__(
 
   // Message package: hidden data, FP8 scales, index at source
   // NOTES: currently we have 3 reserved int fields for future use
-
   using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
   const size_t num_bytes_per_msg =
       sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float))
@@ -305,13 +313,11 @@ __global__ __launch_bounds__(
                responsible_expert_idx) != FINISHED_SUM_TAG * 2) {
         }
         if (dst_rank != rank) {
-          nvshmemi_ibgda_rma_p(
+          nvshmemi_ibgda_amo_nonfetch_add(
               rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
               -num_tokens_sent - 1,
               dst_rank,
-              dst_expert_local_idx,
-              0);
-          nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
+              dst_expert_local_idx);
         } else {
           st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
                         -num_tokens_sent - 1);
@@ -366,16 +372,9 @@ LOW_LATENCY_DISPATCH_RECV:
     EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
                      "Requires more than one warp per group");
     if (sub_warp_id == 1 && lane_id == 0) {
-      if (src_rank != rank) {
-        nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
-        num_recv_tokens = ld_acquire_sys_global(
-            rdma_recv_count + local_expert_idx * num_ranks + src_rank);
-        EP_DEVICE_ASSERT(num_recv_tokens != 0);
-      } else {
-        while ((num_recv_tokens = ld_acquire_global(
-                    rdma_recv_count + local_expert_idx * num_ranks +
-                    src_rank)) == 0) {
-        }
+      while ((num_recv_tokens = ld_acquire_global(
+                  rdma_recv_count + local_expert_idx * num_ranks + src_rank)) ==
+             0) {
       }
       num_recv_tokens = -num_recv_tokens - 1;
       recv_token_begin_idx =
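Both the send and receive sides above rely on the count encoding -num_tokens_sent - 1: zero is reserved to mean "nothing arrived yet", so even an empty batch produces a nonzero flag, the receiver's spin on != 0 terminates, and num_recv_tokens = -num_recv_tokens - 1 recovers the true count. A standalone sketch of that encoding (plain C++; the helper names are hypothetical, not from the patch):

    #include <cassert>

    // Reserve 0 for "no message yet"; any real count maps to a negative value.
    constexpr int encode_count(int num_tokens_sent) { return -num_tokens_sent - 1; }
    constexpr int decode_count(int flag) { return -flag - 1; }

    int main() {
      // Even an empty batch (0 tokens) yields a nonzero flag (-1).
      static_assert(encode_count(0) == -1, "an empty batch must still signal arrival");
      assert(decode_count(encode_count(128)) == 128);  // round-trips exactly
      return 0;
    }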
@@ -539,7 +538,8 @@ __global__ __launch_bounds__(
     int num_experts,
     int rank,
     int num_ranks,
-    int phases) {
+    int phases,
+    bool zero_copy) {
   const auto sm_id = static_cast<int>(blockIdx.x);
   const auto num_sms = static_cast<int>(gridDim.x);
   const auto thread_id = static_cast<int>(threadIdx.x);
@@ -580,7 +580,9 @@ __global__ __launch_bounds__(
     const auto local_expert_idx = responsible_expert_idx % num_local_experts;
     const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
     const auto layout =
-        __ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
+        __ldg(layout_range + local_expert_idx * num_ranks +
+              dst_rank);  // num_recv_tokens, recv_token_begin_idx
+
     const auto local_x = reinterpret_cast<const int4*>(x) +
                          local_expert_idx * num_ranks *
                              num_max_dispatch_tokens_per_rank *
@@ -625,13 +627,14 @@ __global__ __launch_bounds__(
                            st_na_global);
       } else {
         const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
-        UNROLLED_WARP_COPY(7,
-                           lane_id,
-                           hidden_bf16_int4,
-                           buf_int4_ptr,
-                           x_int4,
-                           ld_nc_global,
-                           st_na_global);
+        if (!zero_copy)
+          UNROLLED_WARP_COPY(7,
+                             lane_id,
+                             hidden_bf16_int4,
+                             buf_int4_ptr,
+                             x_int4,
+                             ld_nc_global,
+                             st_na_global);
         nvshmemi_ibgda_put_nbi_warp(dst_ptr,
                                     buf_ptr,
                                     hidden * sizeof(nv_bfloat16),
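When zero_copy is set, the warp-level copy into the staging buffer is skipped on the assumption that the caller has already materialized the bf16 tokens at buf_ptr; nvshmemi_ibgda_put_nbi_warp still issues the RDMA write from that same buffer either way. A simplified host-side illustration of the pattern, not the kernel's actual code (plain C++, hypothetical names):

    #include <cstddef>
    #include <cstring>

    // Stage-and-send vs. zero-copy send from a preregistered buffer.
    void send_tokens(const char* x, char* staging_buf, std::size_t bytes,
                     bool zero_copy, void (*rdma_put)(const char*, std::size_t)) {
      // Default path: copy the tokens into the registered staging buffer first.
      if (!zero_copy) std::memcpy(staging_buf, x, bytes);
      // The RDMA write always reads from the staging buffer, so a zero-copy
      // caller must have produced the data there before calling in.
      rdma_put(staging_buf, bytes);
    }

    int main() {
      char src[8] = "tokens!", buf[8];
      send_tokens(src, buf, sizeof src, /*zero_copy=*/false,
                  [](const char*, std::size_t) { /* RDMA write elided */ });
      return 0;
    }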
@@ -651,11 +654,8 @@ __global__ __launch_bounds__(
       while (ld_acquire_global(atomic_clean_flag) == 0) {
       }
       if (dst_rank != rank) {
-        nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx,
-                             1,
-                             dst_rank,
-                             local_expert_idx,
-                             0);
+        nvshmemi_ibgda_amo_nonfetch_add(
+            rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
       } else {
         st_na_release(rdma_recv_flag + global_expert_idx, 1);
       }
@@ -672,18 +672,9 @@ LOW_LATENCY_COMBINE_RECV:
   if (responsible_expert_idx < num_experts) {
     EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
                      "Invalid number of warps per group");
-    if (sub_warp_id == 0 && lane_id == 0) {
-      // TODO(Xreki): refactor QP indices
-      auto src_rank = responsible_expert_idx / num_local_experts;
-      auto src_expert_idx = responsible_expert_idx % num_local_experts;
-      if (src_rank != rank) {
-        nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
-      } else {
-        while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) ==
-               0) {
-        }
+    if (sub_warp_id == 0 && lane_id == 0)
+      while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0) {
       }
-    }
   }
   cg::this_grid().sync();
 
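With arrival now signaled by a nonfetching atomic add into symmetric memory, the receive side no longer needs the per-QP nvshmemi_ibgda_poll_recv branch: local and remote senders are both observed through one acquire-load spin. A minimal single-process analogue of that spin, using C++ atomics in place of the GPU intrinsics (an illustration of the memory-ordering pattern only, not the kernel's code):

    #include <atomic>
    #include <thread>

    std::atomic<int> recv_flag{0};

    int main() {
      // "Sender": the release store stands in for the remote atomic add's
      // visibility guarantee.
      std::thread sender([] { recv_flag.store(1, std::memory_order_release); });
      // "Receiver": a single spin covers every source, local or remote.
      while (recv_flag.load(std::memory_order_acquire) == 0) {
      }
      sender.join();
      return 0;
    }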
@@ -796,7 +787,8 @@ void combine(void* combined_x,
                 num_experts, \
                 rank, \
                 num_ranks, \
-                phases); \
+                phases, \
+                false); \
   } \
   break
 