@@ -37,164 +37,6 @@ namespace internode {
37
37
38
38
extern nvshmem_team_t cpu_rdma_team;
39
39
40
- template <int kNumThreads , int kNumExpertsPerSM , int kNumRanksPerSM >
41
- __global__ void __launch_bounds__ (kNumThreads , 1 )
42
- get_dispatch_layout(const int64_t * topk_idx,
43
- int * num_tokens_per_rank,
44
- int * num_tokens_per_rdma_rank,
45
- int * num_tokens_per_expert,
46
- bool * is_token_in_rank,
47
- int num_tokens,
48
- int num_topk,
49
- int num_ranks,
50
- int num_experts) {
51
- auto sm_id = static_cast <int >(blockIdx .x );
52
- auto thread_id = static_cast <int >(threadIdx .x );
53
-
54
- // Count expert statistics
55
- __shared__ int num_tokens_per_expert_per_thread[kNumThreads ]
56
- [kNumExpertsPerSM ];
57
- int expert_begin_idx = sm_id * kNumExpertsPerSM ,
58
- expert_end_idx = min (expert_begin_idx + kNumExpertsPerSM , num_experts);
59
- if (expert_begin_idx < expert_end_idx) {
60
- // Per-thread count
61
- #pragma unroll
62
- for (int i = 0 ; i < kNumExpertsPerSM ; ++i)
63
- num_tokens_per_expert_per_thread[thread_id][i] = 0 ;
64
- #pragma unroll
65
- for (int i = thread_id; i < num_tokens; i += kNumThreads ) {
66
- auto shifted_topk_idx = topk_idx + i * num_topk;
67
- #pragma unroll
68
- for (int j = 0 , expert_idx; j < num_topk; ++j) {
69
- expert_idx = static_cast <int >(shifted_topk_idx[j]);
70
- if (expert_begin_idx <= expert_idx && expert_idx < expert_end_idx)
71
- ++num_tokens_per_expert_per_thread[thread_id]
72
- [expert_idx - expert_begin_idx];
73
- }
74
- }
75
- __syncthreads ();
76
-
77
- // Sum up
78
- EP_STATIC_ASSERT (kNumExpertsPerSM <= kNumThreads ,
79
- " Too many experts per SM" );
80
- if (expert_begin_idx + thread_id < expert_end_idx) {
81
- int sum = 0 ;
82
- #pragma unroll
83
- for (int i = 0 ; i < kNumThreads ; ++i)
84
- sum += num_tokens_per_expert_per_thread[i][thread_id];
85
- num_tokens_per_expert[expert_begin_idx + thread_id] = sum;
86
- }
87
- return ;
88
- }
89
-
90
- if (num_tokens_per_rdma_rank != nullptr )
91
- EP_DEVICE_ASSERT (num_ranks % NUM_MAX_NVL_PEERS == 0 &&
92
- num_ranks > NUM_MAX_NVL_PEERS);
93
-
94
- // Count rank statistics
95
- constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS;
96
- __shared__ int num_tokens_per_rank_per_thread[kNumThreads ][kNumRanksPerSM ];
97
- __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads ]
98
- [kNumRDMARanksPerSM ];
99
- auto sm_begin = (num_experts + kNumExpertsPerSM - 1 ) / kNumExpertsPerSM ;
100
- int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM ,
101
- rank_end_idx = min (rank_begin_idx + kNumRanksPerSM , num_ranks);
102
- int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS,
103
- rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS;
104
- if (rank_begin_idx < rank_end_idx) {
105
- const auto num_expert_per_rank = num_experts / num_ranks;
106
- auto expert_begin = rank_begin_idx * num_expert_per_rank;
107
- auto expert_end = rank_end_idx * num_expert_per_rank;
108
-
109
- // Per-thread count
110
- #pragma unroll
111
- for (int i = 0 ; i < kNumRanksPerSM ; ++i)
112
- num_tokens_per_rank_per_thread[thread_id][i] = 0 ;
113
- #pragma unroll
114
- for (int i = 0 ; i < kNumRDMARanksPerSM ; ++i)
115
- num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0 ;
116
- #pragma unroll
117
- for (int i = thread_id; i < num_tokens; i += kNumThreads ) {
118
- auto shifted_topk_idx = topk_idx + i * num_topk;
119
- int is_in_rank[kNumRanksPerSM ] = {0 },
120
- is_in_rdma_rank[kNumRDMARanksPerSM ] = {0 };
121
- #pragma unroll
122
- for (int j = 0 , expert_idx, rank_idx; j < num_topk; ++j) {
123
- expert_idx = static_cast <int >(shifted_topk_idx[j]);
124
- if (expert_begin <= expert_idx && expert_idx < expert_end) {
125
- // Count single rank
126
- rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx;
127
- is_in_rank[rank_idx]++,
128
- is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS]++;
129
- }
130
- }
131
-
132
- auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks;
133
- #pragma unroll
134
- for (int j = 0 ; j + rank_begin_idx < rank_end_idx; ++j) {
135
- shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0 );
136
- num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0 );
137
- }
138
-
139
- #pragma unroll
140
- for (int j = 0 ; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++j)
141
- num_tokens_per_rdma_rank_per_thread[thread_id][j] +=
142
- (is_in_rdma_rank[j] > 0 );
143
- }
144
- __syncthreads ();
145
-
146
- // Sum up
147
- EP_STATIC_ASSERT (kNumRanksPerSM <= kNumThreads , " Too many ranks per SM" );
148
- if (rank_begin_idx + thread_id < rank_end_idx) {
149
- int sum = 0 ;
150
- #pragma unroll
151
- for (int i = 0 ; i < kNumThreads ; ++i)
152
- sum += num_tokens_per_rank_per_thread[i][thread_id];
153
- num_tokens_per_rank[rank_begin_idx + thread_id] = sum;
154
- }
155
-
156
- if (num_tokens_per_rdma_rank != nullptr &&
157
- rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) {
158
- int sum = 0 ;
159
- #pragma unroll
160
- for (int i = 0 ; i < kNumThreads ; ++i)
161
- sum += num_tokens_per_rdma_rank_per_thread[i][thread_id];
162
- num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum;
163
- }
164
- }
165
- }
166
-
167
- void get_dispatch_layout (const int64_t * topk_idx,
168
- int * num_tokens_per_rank,
169
- int * num_tokens_per_rdma_rank,
170
- int * num_tokens_per_expert,
171
- bool * is_token_in_rank,
172
- int num_tokens,
173
- int num_topk,
174
- int num_ranks,
175
- int num_experts,
176
- cudaStream_t stream) {
177
- constexpr int kNumThreads = 256 , kNumExpertsPerSM = 32 , kNumRanksPerSM = 8 ;
178
- int num_sms = ((num_experts + kNumExpertsPerSM - 1 ) / kNumExpertsPerSM ) +
179
- (num_ranks + kNumRanksPerSM - 1 ) / kNumRanksPerSM ;
180
- EP_STATIC_ASSERT (kNumExpertsPerSM % NUM_MAX_NVL_PEERS == 0 ,
181
- " Invalid number of experts per SM" );
182
-
183
- SETUP_LAUNCH_CONFIG (num_sms, kNumThreads , stream);
184
- LAUNCH_KERNEL (
185
- &cfg,
186
- (get_dispatch_layout<kNumThreads , kNumExpertsPerSM , kNumRanksPerSM >),
187
- topk_idx,
188
- num_tokens_per_rank,
189
- num_tokens_per_rdma_rank,
190
- num_tokens_per_expert,
191
- is_token_in_rank,
192
- num_tokens,
193
- num_topk,
194
- num_ranks,
195
- num_experts);
196
- }
197
-
198
40
struct SourceMeta {
199
41
int src_rdma_rank, is_token_in_nvl_rank_bits;
200
42
0 commit comments