@@ -225,13 +225,33 @@ __global__ __launch_bounds__(
225
225
rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
226
226
slot_idx * num_bytes_per_msg;
227
227
if (dst_rank != rank) {
228
- nvshmemi_ibgda_put_nbi_warp (dst_ptr,
229
- src_ptr,
230
- num_bytes_per_msg,
231
- dst_rank,
232
- dst_expert_local_idx,
233
- lane_id,
234
- slot_idx);
228
+ void * peer_base_addr = reinterpret_cast <void *>(
229
+ __ldg (reinterpret_cast <const uint64_t *>(
230
+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
231
+ dst_rank));
232
+ if (peer_base_addr) {
233
+ char * req_rptr_actual =
234
+ reinterpret_cast <char *>(peer_base_addr) +
235
+ (reinterpret_cast <char *>(dst_ptr) -
236
+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base ));
237
+ const auto * src_int4_ptr = reinterpret_cast <const int4 *>(src_ptr);
238
+ const auto * dst_int4_ptr = reinterpret_cast <int4 *>(req_rptr_actual);
239
+ UNROLLED_WARP_COPY (8 ,
240
+ lane_id,
241
+ num_int4_per_msg,
242
+ dst_int4_ptr,
243
+ src_int4_ptr,
244
+ ld_nc_global,
245
+ st_na_global);
246
+ } else {
247
+ nvshmemi_ibgda_put_nbi_warp (dst_ptr,
248
+ src_ptr,
249
+ num_bytes_per_msg,
250
+ dst_rank,
251
+ dst_expert_local_idx,
252
+ lane_id,
253
+ slot_idx);
254
+ }
235
255
} else {
236
256
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
237
257
const auto * src_int4_ptr = reinterpret_cast <const int4 *>(src_ptr);
@@ -313,11 +333,24 @@ __global__ __launch_bounds__(
313
333
responsible_expert_idx) != FINISHED_SUM_TAG * 2 ) {
314
334
}
315
335
if (dst_rank != rank) {
316
- nvshmemi_ibgda_amo_nonfetch_add (
317
- rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
318
- -num_tokens_sent - 1 ,
319
- dst_rank,
320
- dst_expert_local_idx);
336
+ void * peer_base_addr = reinterpret_cast <void *>(
337
+ __ldg (reinterpret_cast <const uint64_t *>(
338
+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
339
+ dst_rank));
340
+ if (peer_base_addr) { // P2P enabled
341
+ int * rptr_actual = reinterpret_cast <int *>(
342
+ reinterpret_cast <char *>(peer_base_addr) +
343
+ (reinterpret_cast <char *>(rdma_recv_count +
344
+ dst_expert_local_idx * num_ranks + rank) -
345
+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base )));
346
+ st_na_release (rptr_actual, -num_tokens_sent - 1 );
347
+ } else {
348
+ nvshmemi_ibgda_amo_nonfetch_add (
349
+ rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
350
+ -num_tokens_sent - 1 ,
351
+ dst_rank,
352
+ dst_expert_local_idx);
353
+ }
321
354
} else {
322
355
st_na_release (rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
323
356
-num_tokens_sent - 1 );
@@ -635,13 +668,32 @@ __global__ __launch_bounds__(
635
668
x_int4,
636
669
ld_nc_global,
637
670
st_na_global);
638
- nvshmemi_ibgda_put_nbi_warp (dst_ptr,
639
- buf_ptr,
640
- hidden * sizeof (nv_bfloat16),
641
- dst_rank,
642
- local_expert_idx,
643
- lane_id,
644
- token_idx - offset);
671
+ void * peer_base_addr = reinterpret_cast <void *>(
672
+ __ldg (reinterpret_cast <const uint64_t *>(
673
+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
674
+ dst_rank));
675
+ if (peer_base_addr) {
676
+ char * req_rptr_actual =
677
+ reinterpret_cast <char *>(peer_base_addr) +
678
+ (reinterpret_cast <char *>(dst_ptr) -
679
+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base ));
680
+ const auto dst_int4_ptr = reinterpret_cast <int4 *>(req_rptr_actual);
681
+ UNROLLED_WARP_COPY (7 ,
682
+ lane_id,
683
+ hidden_bf16_int4,
684
+ dst_int4_ptr,
685
+ x_int4,
686
+ ld_nc_global,
687
+ st_na_global);
688
+ } else {
689
+ nvshmemi_ibgda_put_nbi_warp (dst_ptr,
690
+ buf_ptr,
691
+ hidden * sizeof (nv_bfloat16),
692
+ dst_rank,
693
+ local_expert_idx,
694
+ lane_id,
695
+ token_idx - offset);
696
+ }
645
697
}
646
698
}
647
699
@@ -654,8 +706,22 @@ __global__ __launch_bounds__(
654
706
while (ld_acquire_global (atomic_clean_flag) == 0 ) {
655
707
}
656
708
if (dst_rank != rank) {
657
- nvshmemi_ibgda_amo_nonfetch_add (
658
- rdma_recv_flag + global_expert_idx, 1 , dst_rank, local_expert_idx);
709
+ void * peer_base_addr = reinterpret_cast <void *>(
710
+ __ldg (reinterpret_cast <const uint64_t *>(
711
+ nvshmemi_device_state_d.peer_heap_base_p2p ) +
712
+ dst_rank));
713
+ if (peer_base_addr) {
714
+ int * req_rptr_actual = reinterpret_cast <int *>(
715
+ reinterpret_cast <char *>(peer_base_addr) +
716
+ (reinterpret_cast <char *>(rdma_recv_flag + global_expert_idx) -
717
+ reinterpret_cast <char *>(nvshmemi_device_state_d.heap_base )));
718
+ st_na_release (req_rptr_actual, 1 );
719
+ } else {
720
+ nvshmemi_ibgda_amo_nonfetch_add (rdma_recv_flag + global_expert_idx,
721
+ 1 ,
722
+ dst_rank,
723
+ local_expert_idx);
724
+ }
659
725
} else {
660
726
st_na_release (rdma_recv_flag + global_expert_idx, 1 );
661
727
}
0 commit comments