Commit 16cbd5f

Move get_dispatch_layout from internode.cu to runtime.cu and enable it for intranode. (PaddlePaddle#71657)
1 parent 7d71132 commit 16cbd5f

File tree

4 files changed: +170 −170 lines

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp (+1 −2)

@@ -345,12 +345,12 @@ Buffer::get_dispatch_layout(const deep_ep::detail::Tensor& topk_idx,
         paddle::experimental::empty({num_tokens, num_ranks},
                                     phi::DataType::BOOL,
                                     phi::GPUPlace(device_id)));
-#ifdef PADDLE_WITH_NVSHMEM
   if (is_internode_available())
     num_tokens_per_rdma_rank =
         ConvertPaddleTensorToDetailTensor(paddle::experimental::empty(
             {num_rdma_ranks}, phi::DataType::INT32, phi::GPUPlace(device_id)));

+  // get_dispatch_layout is used for both intranode and internode.
   internode::get_dispatch_layout(
       topk_idx.data_ptr<int64_t>(),
       num_tokens_per_rank.data_ptr<int>(),
@@ -364,7 +364,6 @@ Buffer::get_dispatch_layout(const deep_ep::detail::Tensor& topk_idx,
       num_ranks,
       num_experts,
       comm_stream);
-#endif

   // Wait streams
   std::optional<EventHandle> event;
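Note: with the PADDLE_WITH_NVSHMEM guard gone, this call path now runs in intranode-only builds as well. A minimal sketch of the resulting call, assuming optional-tensor handling that the visible hunks elide (the ternary below is a guess, not verbatim source); the kernel side tolerates a null RDMA counter, as its nullptr checks in the removed internode.cu code show:

  // Sketch (assumed): intranode builds never allocate num_tokens_per_rdma_rank,
  // so the kernel receives nullptr and skips the RDMA-rank statistics.
  internode::get_dispatch_layout(
      topk_idx.data_ptr<int64_t>(),
      num_tokens_per_rank.data_ptr<int>(),
      num_tokens_per_rdma_rank.has_value()
          ? num_tokens_per_rdma_rank->data_ptr<int>()
          : nullptr,  // intranode: no RDMA ranks to count
      num_tokens_per_expert.data_ptr<int>(),
      is_token_in_rank.data_ptr<bool>(),
      num_tokens, num_topk, num_ranks, num_experts, comm_stream);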

paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh (+9 −8)

@@ -29,10 +29,10 @@ void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream

 } // namespace intranode

-#ifdef PADDLE_WITH_NVSHMEM
 // Internode runtime
 namespace internode {

+#ifdef PADDLE_WITH_NVSHMEM
 std::vector<uint8_t> get_unique_id();

 int init(const std::vector<uint8_t> &root_unique_id_val, int rank, int num_ranks, bool low_latency_mode);

@@ -44,9 +44,15 @@ void free(void *ptr);
 void barrier();

 void finalize();
+#endif // PADDLE_WITH_NVSHMEM
+
+void get_dispatch_layout(const int64_t* topk_idx,
+                         int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
+                         int* num_tokens_per_expert, bool* is_token_in_rank,
+                         int num_tokens, int num_topk, int num_ranks, int num_experts,
+                         cudaStream_t stream);

 } // namespace internode
-#endif // PADDLE_WITH_NVSHMEM

 // Intranode kernels
 namespace intranode {

@@ -90,12 +96,6 @@ namespace internode {

 int get_source_meta_bytes();

-void get_dispatch_layout(const int64_t* topk_idx,
-                         int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
-                         int* num_tokens_per_expert, bool* is_token_in_rank,
-                         int num_tokens, int num_topk, int num_ranks, int num_experts,
-                         cudaStream_t stream);
-
 void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
                      const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped,
                      const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,

@@ -172,4 +172,5 @@ void combine(void* combined_x,

 } // namespace internode_ll
 #endif // PADDLE_WITH_NVSHMEM
+
 } // namespace deep_ep
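Net effect of the header edits (an abridged reconstruction assembled from the hunks above, not a verbatim copy of api.cuh): the internode namespace now opens unconditionally, only the NVSHMEM runtime entry points stay behind the guard, and get_dispatch_layout is declared for every build.

  namespace internode {

  #ifdef PADDLE_WITH_NVSHMEM
  // NVSHMEM-only runtime (init/alloc/free/barrier/finalize, abridged here).
  std::vector<uint8_t> get_unique_id();
  void finalize();
  #endif // PADDLE_WITH_NVSHMEM

  // Declared unconditionally: used by both intranode and internode paths.
  void get_dispatch_layout(const int64_t* topk_idx,
                           int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
                           int* num_tokens_per_expert, bool* is_token_in_rank,
                           int num_tokens, int num_topk, int num_ranks, int num_experts,
                           cudaStream_t stream);

  } // namespace internode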

paddle/fluid/distributed/collective/deep_ep/kernels/internode.cu (−158)

@@ -37,164 +37,6 @@ namespace internode {

 extern nvshmem_team_t cpu_rdma_team;

-template <int kNumThreads, int kNumExpertsPerSM, int kNumRanksPerSM>
-__global__ void __launch_bounds__(kNumThreads, 1)
-    get_dispatch_layout(const int64_t* topk_idx,
-                        int* num_tokens_per_rank,
-                        int* num_tokens_per_rdma_rank,
-                        int* num_tokens_per_expert,
-                        bool* is_token_in_rank,
-                        int num_tokens,
-                        int num_topk,
-                        int num_ranks,
-                        int num_experts) {
-  auto sm_id = static_cast<int>(blockIdx.x);
-  auto thread_id = static_cast<int>(threadIdx.x);
-
-  // Count expert statistics
-  __shared__ int num_tokens_per_expert_per_thread[kNumThreads]
-                                                 [kNumExpertsPerSM];
-  int expert_begin_idx = sm_id * kNumExpertsPerSM,
-      expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts);
-  if (expert_begin_idx < expert_end_idx) {
-    // Per-thread count
-#pragma unroll
-    for (int i = 0; i < kNumExpertsPerSM; ++i)
-      num_tokens_per_expert_per_thread[thread_id][i] = 0;
-#pragma unroll
-    for (int i = thread_id; i < num_tokens; i += kNumThreads) {
-      auto shifted_topk_idx = topk_idx + i * num_topk;
-#pragma unroll
-      for (int j = 0, expert_idx; j < num_topk; ++j) {
-        expert_idx = static_cast<int>(shifted_topk_idx[j]);
-        if (expert_begin_idx <= expert_idx && expert_idx < expert_end_idx)
-          ++num_tokens_per_expert_per_thread[thread_id]
-                                            [expert_idx - expert_begin_idx];
-      }
-    }
-    __syncthreads();
-
-    // Sum up
-    EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads,
-                     "Too many experts per SM");
-    if (expert_begin_idx + thread_id < expert_end_idx) {
-      int sum = 0;
-#pragma unroll
-      for (int i = 0; i < kNumThreads; ++i)
-        sum += num_tokens_per_expert_per_thread[i][thread_id];
-      num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
-    }
-    return;
-  }
-
-  if (num_tokens_per_rdma_rank != nullptr)
-    EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 &&
-                     num_ranks > NUM_MAX_NVL_PEERS);
-
-  // Count rank statistics
-  constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
-  __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM];
-  __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads]
-                                                    [kNumRDMARanksPerSM];
-  auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM;
-  int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM,
-      rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks);
-  int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS,
-      rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
-  if (rank_begin_idx < rank_end_idx) {
-    const auto num_expert_per_rank = num_experts / num_ranks;
-    auto expert_begin = rank_begin_idx * num_expert_per_rank;
-    auto expert_end = rank_end_idx * num_expert_per_rank;
-
-    // Per-thread count
-#pragma unroll
-    for (int i = 0; i < kNumRanksPerSM; ++i)
-      num_tokens_per_rank_per_thread[thread_id][i] = 0;
-#pragma unroll
-    for (int i = 0; i < kNumRDMARanksPerSM; ++i)
-      num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0;
-#pragma unroll
-    for (int i = thread_id; i < num_tokens; i += kNumThreads) {
-      auto shifted_topk_idx = topk_idx + i * num_topk;
-      int is_in_rank[kNumRanksPerSM] = {0},
-          is_in_rdma_rank[kNumRDMARanksPerSM] = {0};
-#pragma unroll
-      for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) {
-        expert_idx = static_cast<int>(shifted_topk_idx[j]);
-        if (expert_begin <= expert_idx && expert_idx < expert_end) {
-          // Count single rank
-          rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
-          is_in_rank[rank_idx]++,
-              is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS]++;
-        }
-      }
-
-      auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
-#pragma unroll
-      for (int j = 0; j + rank_begin_idx < rank_end_idx; ++j) {
-        shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0);
-        num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0);
-      }
-
-#pragma unroll
-      for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++j)
-        num_tokens_per_rdma_rank_per_thread[thread_id][j] +=
-            (is_in_rdma_rank[j] > 0);
-    }
-    __syncthreads();
-
-    // Sum up
-    EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM");
-    if (rank_begin_idx + thread_id < rank_end_idx) {
-      int sum = 0;
-#pragma unroll
-      for (int i = 0; i < kNumThreads; ++i)
-        sum += num_tokens_per_rank_per_thread[i][thread_id];
-      num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
-    }
-
-    if (num_tokens_per_rdma_rank != nullptr &&
-        rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
-      int sum = 0;
-#pragma unroll
-      for (int i = 0; i < kNumThreads; ++i)
-        sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
-      num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
-    }
-  }
-}
-
-void get_dispatch_layout(const int64_t* topk_idx,
-                         int* num_tokens_per_rank,
-                         int* num_tokens_per_rdma_rank,
-                         int* num_tokens_per_expert,
-                         bool* is_token_in_rank,
-                         int num_tokens,
-                         int num_topk,
-                         int num_ranks,
-                         int num_experts,
-                         cudaStream_t stream) {
-  constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
-  int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) +
-                (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
-  EP_STATIC_ASSERT(kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0,
-                   "Invalid number of experts per SM");
-
-  SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
-  LAUNCH_KERNEL(
-      &cfg,
-      (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
-      topk_idx,
-      num_tokens_per_rank,
-      num_tokens_per_rdma_rank,
-      num_tokens_per_expert,
-      is_token_in_rank,
-      num_tokens,
-      num_topk,
-      num_ranks,
-      num_experts);
-}
-
 struct SourceMeta {
   int src_rdma_rank, is_token_in_nvl_rank_bits;
