Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions csrc/kernels/internode_ll.cu
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warp_groups = ((phases & LOW_LATENCY_RECV_PHASE) == 0)
? 9
: ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 32 / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);
EP_HOST_ASSERT(kNumMaxTopK + 1 <= num_warp_groups * num_warps_per_group);
Expand Down Expand Up @@ -548,7 +550,9 @@ void combine(void* combined_x,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases, bool zero_copy) {
constexpr int kNumMaxTopk = 9;
const int num_warp_groups = ceil_div(num_experts, num_device_sms);
const int num_warp_groups = ((phases & LOW_LATENCY_RECV_PHASE) == 0)
? 9
: ceil_div(num_experts, num_device_sms);
const int num_warps_per_group = 32 / num_warp_groups;
EP_HOST_ASSERT(num_warp_groups > 0 and num_warps_per_group > 0);

Expand Down