From 2bf764cfd442b2423a13d423bd72135a03187896 Mon Sep 17 00:00:00 2001
From: shifangx
Date: Thu, 24 Jul 2025 22:14:17 -0700
Subject: [PATCH 01/12] support NVFP4 data format in low latency dispatch

---
 csrc/deep_ep.cpp             |  26 +++-
 csrc/deep_ep.hpp             |   4 +-
 csrc/kernels/api.cuh         |   4 +-
 csrc/kernels/internode_ll.cu | 237 ++++++++++++++++++++++++++++++++---
 deep_ep/buffer.py            |  19 ++-
 tests/test_low_latency.py    |  52 ++++++--
 tests/utils.py               |  62 ++++++++-
 7 files changed, 367 insertions(+), 37 deletions(-)

diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp
index 0789cd58..67393fbc 100644
--- a/csrc/deep_ep.cpp
+++ b/csrc/deep_ep.cpp
@@ -1087,12 +1087,14 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int
 #endif
 }
 
-std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
 Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                              const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                              const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                             const std::optional<torch::Tensor>& x_sf_scale,
                              int num_max_dispatch_tokens_per_rank, int num_experts,
                              bool use_fp8, bool round_scale, bool use_ue8m0,
+                             bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
                              bool async, bool return_recv_hook) {
 #ifndef DISABLE_NVSHMEM
     EP_HOST_ASSERT(low_latency_mode);
@@ -1137,8 +1139,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
     stream_wait(launch_stream, compute_stream);
 
     // Allocate packed tensors
-    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden},
-                                      x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16));
+    constexpr int NUM_ELEMS_PER_PACK = 8;
+    auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, use_nvfp4 ? hidden / NUM_ELEMS_PER_PACK : hidden},
+                                      x.options().dtype(use_nvfp4 ? torch::kInt32 : (use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)));
     auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA));
     auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA));
     auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA));
@@ -1146,6 +1149,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
 
     // Allocate column-majored scales
     auto packed_recv_x_scales = std::optional<torch::Tensor>();
     void* packed_recv_x_scales_ptr = nullptr;
+    auto packed_recv_x_sf_scale = std::optional<torch::Tensor>();
+    void* packed_recv_x_sf_scale_ptr = nullptr;
     EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");
     if (use_fp8) {
@@ -1161,16 +1166,26 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         }
         packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
         packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+    } else if (use_nvfp4) {
+        constexpr int SF_VEC_SIZE = 16;
+        constexpr int NUM_SF_ELEMS_PER_PACK = 4;
+        packed_recv_x_scales = torch::empty({num_local_experts, hidden / (SF_VEC_SIZE * NUM_SF_ELEMS_PER_PACK), num_ranks * num_max_dispatch_tokens_per_rank},
+                                            torch::dtype(torch::kInt).device(torch::kCUDA));
+        packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2);
+        packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
+        packed_recv_x_sf_scale = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
+        packed_recv_x_sf_scale_ptr = packed_recv_x_sf_scale->data_ptr();
     }
 
     // Kernel launch
     auto next_clean_meta = next_buffer.clean_meta();
     auto launcher = [=](int phases) {
-        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr,
+        internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_x_sf_scale_ptr,
                                packed_recv_src_info.data_ptr<int>(), packed_recv_layout_range.data_ptr<int64_t>(), packed_recv_count.data_ptr<int>(),
                                cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr<int>() : nullptr,
                                dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr<int64_t>() : nullptr,
+                               x_sf_scale.has_value() ? x_sf_scale->data_ptr<float>() : nullptr,
                                buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, buffer.dispatch_rdma_send_buffer,
                                x.data_ptr(), topk_idx.data_ptr<int64_t>(),
@@ -1178,6 +1193,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
                                num_tokens, hidden, num_max_dispatch_tokens_per_rank,
                                num_topk, num_experts, rank, num_ranks,
                                use_fp8, round_scale, use_ue8m0,
+                               use_nvfp4, use_ue8m0_for_nvfp4_sf,
                                workspace, num_device_sms,
                                launch_stream, phases);
     };
@@ -1199,7 +1215,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
         recv_hook = [=]() { launcher(LOW_LATENCY_RECV_PHASE); };
 
     // Return values
-    return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
+    return {packed_recv_x, packed_recv_x_scales, packed_recv_x_sf_scale, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook};
 #else
     EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation");
     return {};
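For context on the allocations above: with NVFP4 each int32 of packed_recv_x packs eight 4-bit e2m1 values, packed_recv_x_scales packs four FP8 scale factors (one per 16 elements) into each int32, and packed_recv_x_sf_scale carries one float32 per received token slot. A minimal sketch of that arithmetic, assuming an illustrative hidden size of 7168; the names below are not part of the patch:

    // Illustrative only: per-token element counts implied by the allocations above,
    // assuming `hidden` is a multiple of 64 (16-element SF blocks, 4 SFs per int32).
    #include <cassert>
    #include <cstdio>

    int main() {
        const int hidden = 7168;                    // example hidden size
        const int payload_words = hidden / 8;       // 8 x e2m1 (4-bit) values per int32
        const int num_sf = hidden / 16;             // one FP8 scale factor per 16 elements
        const int sf_words = num_sf / 4;            // 4 FP8 scale factors per int32
        assert(hidden % 64 == 0);
        std::printf("int32 payload words per token: %d\n", payload_words);
        std::printf("int32 scale words per token:   %d\n", sf_words);
        // Plus one float32 per token slot in packed_recv_x_sf_scale (the per-token global scale).
        return 0;
    }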
diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp
index aa62ccb0..27ff4951 100644
--- a/csrc/deep_ep.hpp
+++ b/csrc/deep_ep.hpp
@@ -143,12 +143,14 @@ struct Buffer {
 
     void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts);
 
-    std::tuple<torch::Tensor, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
+    std::tuple<torch::Tensor, std::optional<torch::Tensor>, std::optional<torch::Tensor>, torch::Tensor, torch::Tensor, torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
     low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx,
                          const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
                          const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
+                         const std::optional<torch::Tensor>& x_sf_scale,
                          int num_max_dispatch_tokens_per_rank, int num_experts,
                          bool use_fp8, bool round_scale, bool use_ue8m0,
+                         bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
                          bool async, bool return_recv_hook);
 
     std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh
index d34775fd..6540443c 100644
--- a/csrc/kernels/api.cuh
+++ b/csrc/kernels/api.cuh
@@ -139,17 +139,19 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                               int* clean_1, int num_clean_int_1,
                               cudaStream_t stream);
 
-void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
+void dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf_scale,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
+              const float* x_sf_scale,
              void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
              const void* x, const int64_t* topk_idx,
              int* next_clean, int num_next_clean_int,
             int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
             int num_topk, int num_experts, int rank, int num_ranks,
             bool use_fp8, bool round_scale, bool use_ue8m0,
+            bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
             void* workspace, int num_device_sms,
             cudaStream_t stream, int phases);
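A hedged caller-side sketch of the widened Buffer::low_latency_dispatch signature declared above, showing where the new optional output and flags sit. Only the tuple layout and argument order come from this patch; the surrounding variables and the direct C++ call are illustrative:

    // Illustrative only: unpacking the 8-element tuple returned after this patch.
    auto [packed_recv_x,          // int32-packed e2m1 payload when use_nvfp4, else FP8/BF16
          packed_recv_x_scales,   // optional: FP8/UE8M0 scales, or int32-packed FP8 SFs for NVFP4
          packed_recv_x_sf_scale, // optional: new per-token float32 global scale (NVFP4 only)
          packed_recv_count,
          packed_recv_src_info,
          packed_recv_layout_range,
          event, recv_hook] =
        buffer.low_latency_dispatch(x, topk_idx,
                                    /*cumulative_local_expert_recv_stats=*/std::nullopt,
                                    /*dispatch_wait_recv_cost_stats=*/std::nullopt,
                                    /*x_sf_scale=*/std::nullopt,
                                    num_max_dispatch_tokens_per_rank, num_experts,
                                    /*use_fp8=*/false, /*round_scale=*/false, /*use_ue8m0=*/false,
                                    /*use_nvfp4=*/true, /*use_ue8m0_for_nvfp4_sf=*/false,
                                    /*async=*/false, /*return_recv_hook=*/false);

As in the FP8 path, packed_recv_x_sf_scale stays std::nullopt unless use_nvfp4 is set, since it is only allocated in the NVFP4 branch.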
diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 391a4b3d..f13880f7 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -36,13 +36,158 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
                   clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, bool kUseUE8M0, int kHidden>
+constexpr int CVT_ELTS_PER_THREAD = 8;
+constexpr int SF_VEC_SIZE = 16;
+
+struct PackedVec {
+    __nv_bfloat162 elts[4];
+};
+
+using Type = __nv_bfloat16;
+
+__device__ __forceinline__ float exp2f_rcp(uint8_t exp) {
+    constexpr uint32_t FP32_EXPONENT_BIAS = 127;
+    return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(exp));
+}
+
+// Fast reciprocal.
+inline __device__ float reciprocal_approximate_ftz(float a) {
+    float b;
+    asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
+    return b;
+}
+
+// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
+    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+    uint32_t val;
+    asm volatile(
+        "{\n"
+        ".reg .b8 byte0;\n"
+        ".reg .b8 byte1;\n"
+        ".reg .b8 byte2;\n"
+        ".reg .b8 byte3;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+        "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+        "}"
+        : "=r"(val)
+        : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x),
+          "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
+    return val;
+    #else
+    // static_assert(false, "not supported.");
+    return 0;
+    #endif
+}
+
+// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
+    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+    uint32_t val;
+    asm volatile(
+        "{\n"
+        ".reg .b8 byte0;\n"
+        ".reg .b8 byte1;\n"
+        ".reg .b8 byte2;\n"
+        ".reg .b8 byte3;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+        "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+        "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+        "}"
+        : "=r"(val)
+        : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), "f"(array[4]), "f"(array[5]),
+          "f"(array[6]), "f"(array[7]));
+    return val;
+    #else
+    // static_assert(false, "not supported.");
+    return 0;
+    #endif
+}
+
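For reviewers unfamiliar with the e2m1 format targeted by the PTX conversions above: each nibble is a sign bit plus a 3-bit code that indexes the magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}. A rough scalar sketch follows; it clamps like satfinite and picks the nearest representable magnitude, but it does not reproduce the instruction's exact rounding and tie-breaking, and the function name is illustrative:

    // Illustrative only: nearest-value e2m1 encoder (sign bit 3, magnitude code in bits 0..2).
    #include <cstdint>
    #include <cmath>

    inline uint8_t e2m1_encode_approx(float x) {
        const float grid[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
        const uint8_t sign = x < 0.0f ? 0x8 : 0x0;
        const float mag = fabsf(x) > 6.0f ? 6.0f : fabsf(x);   // satfinite-style clamp to +-6
        uint8_t best = 0;
        for (uint8_t i = 1; i < 8; ++i)
            if (fabsf(mag - grid[i]) < fabsf(mag - grid[best]))
                best = i;
        return sign | best;
    }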
+// Quantizes the provided PackedVec into the uint32_t output
+template <bool UE8M0_SF>
+__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) {
+    // Get absolute maximum values among the local 8 values.
+    auto localMax = __habs2(vec.elts[0]);
+
+// Local maximum value.
+#pragma unroll
+    for (int i = 1; i < CVT_ELTS_PER_THREAD / 2; i++) {
+        localMax = __hmax2(localMax, __habs2(vec.elts[i]));
+    }
+
+    constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_ELTS_PER_THREAD;
+    EP_STATIC_ASSERT(CVT_NUM_THREADS_PER_SF == 2 or CVT_NUM_THREADS_PER_SF == 4, "Invalid number of threads per SF");
+    // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32).
+    localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
+    if constexpr (CVT_NUM_THREADS_PER_SF == 4) {
+        localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
+    }
+    // Get the final absolute maximum values.
+    float vecMax = float(__hmax(localMax.x, localMax.y));
+
+    // 8 bits representation of the SF.
+    uint8_t fp8SFVal;
+    float outputScale;
+    // Write the SF to global memory (STG.8).
+    if constexpr (UE8M0_SF) {
+        __nv_fp8_e8m0 tmp;
+        // Scale the max value to the range of E2m1.
+        vecMax *= reciprocal_approximate_ftz(6.0f);
+        tmp.__x = __nv_cvt_float_to_e8m0(vecMax, __NV_SATFINITE, cudaRoundPosInf);
+        fp8SFVal = tmp.__x;
+        outputScale = exp2f_rcp(fp8SFVal);
+    } else {
+        // Get the SF (max value of the vector / max value of e2m1).
+        // maximum value of e2m1 = 6.0.
+        // TODO: use half as compute data type.
+        auto SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
+        // Here SFValue is always positive, so E4M3 is the same as UE4M3.
+        __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+        fp8SFVal = tmp.__x;
+        SFValue = static_cast<float>(tmp);
+        // Get the output scale.
+        // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) * reciprocal(SFScaleVal))
+        outputScale = SFValue != 0
+                          ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal))
+                          : 0.0f;
+    }
+
+    if (SFout) {
+        // Write the SF to global memory (STG.8).
+        *SFout = fp8SFVal;
+    }
+
+    // Convert the input to float.
+    float2 fp2Vals[CVT_ELTS_PER_THREAD / 2];
+
+#pragma unroll
+    for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) {
+        fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
+        fp2Vals[i].x *= outputScale;
+        fp2Vals[i].y *= outputScale;
+    }
+
+    // Convert to e2m1 values.
+    uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+
+    // Write the e2m1 values to global memory.
+    return e2m1Vec;
+}
+
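A hedged scalar restatement of the E4M3 (non-UE8M0) path of cvt_warp_fp16_to_fp4 for a single 16-element scale block, with the warp shuffles replaced by a plain loop. The function name and the loop form are illustrative; only the recipe itself comes from the kernel above:

    // Illustrative only: one SF block of 16 values already widened to float.
    // Returns the scale actually applied to the elements; *sf_out gets the E4M3 SF byte.
    #include <cuda_fp8.h>
    #include <math.h>

    inline float quantize_block_reference(const float block[16], float sf_scale_val, uint8_t* sf_out) {
        // amax over the 16 values covered by one scale factor.
        float vec_max = 0.0f;
        for (int i = 0; i < 16; ++i)
            vec_max = fmaxf(vec_max, fabsf(block[i]));

        // SF = sf_scale_val * amax / 6.0 (6.0 is the largest e2m1 magnitude), stored as E4M3.
        __nv_fp8_e4m3 sf_fp8 = __nv_fp8_e4m3(sf_scale_val * (vec_max / 6.0f));
        *sf_out = sf_fp8.__x;

        // The kernel multiplies every element by 1 / (fp32(SF) / sf_scale_val) before the
        // e2m1 conversion, so the receiver dequantizes with the stored SF and sf_scale_val.
        const float sf_decoded = static_cast<float>(sf_fp8);
        return sf_decoded != 0.0f ? sf_scale_val / sf_decoded : 0.0f;
    }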
+template <bool kUseFP8, bool kUseUE8M0, bool kUseNVFP4, bool kUseUE8M0ForNVFP4SF, int kHidden>
 __global__ __launch_bounds__(1024, 1) void
-dispatch(void* packed_recv_x, void* packed_recv_x_scales,
+dispatch(void* packed_recv_x, void* packed_recv_x_scales, void* packed_recv_x_sf_scale,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count,
          int* cumulative_local_expert_recv_stats,
          int64_t* dispatch_wait_recv_cost_stats,
+         const float* x_sf_scale,
          void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
          const void* x, const int64_t* topk_idx,
          int* atomic_counter_per_expert, int* atomic_finish_counter_per_expert,
@@ -62,20 +207,28 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto responsible_expert_idx = sm_id * num_warp_groups + warp_group_id;
 
     // May extract UE8M0 from the scales
-    using scale_t = std::conditional_t<kUseUE8M0, uint8_t, float>;
-    using packed_t = std::conditional_t<kUseUE8M0, uint32_t, float>;
+    using scale_t = std::conditional_t<kUseNVFP4, uint8_t, std::conditional_t<kUseUE8M0, uint8_t, float>>;
+    using packed_t = std::conditional_t<kUseNVFP4, int32_t, std::conditional_t<kUseUE8M0, uint32_t, float>>;
     EP_STATIC_ASSERT(sizeof(packed_t) % sizeof(scale_t) == 0, "Invalid vector length");
+    EP_STATIC_ASSERT(!(kUseFP8 && kUseNVFP4), "FP8 and NVFP4 cannot be used together");
 
     // FP8 staffs
-    constexpr int kNumPerChannels = 128;
+    constexpr int kNumPerChannels = kUseNVFP4 ? 16 : 128;
     const int num_scales = kHidden / kNumPerChannels;
-    const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
+    constexpr size_t hidden_bytes =
+        kUseNVFP4
+            ? kHidden * sizeof(__nv_fp8_storage_t) / 2
+            : kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
     const size_t hidden_int4 = hidden_bytes / sizeof(int4);
 
     // Message package: index at source (int), 3 reserved int fields, hidden data, FP8 scales
     // NOTES: currently we have 3 reserved int fields for future use
-    using vec_t = std::conditional_t<kUseFP8, int2, int4>;
-    const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
+    using vec_t = std::conditional_t<
+        kUseNVFP4,
+        int32_t,
+        std::conditional_t<kUseFP8, int2, int4>>;
+    using rdma_x_scale_t = std::conditional_t<kUseNVFP4, uint8_t, float>;
+    const size_t num_bytes_per_msg = sizeof(int4) + ((kUseFP8 || kUseNVFP4) ? (hidden_bytes + num_scales * sizeof(rdma_x_scale_t)) : hidden_bytes);
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
@@ -100,12 +253,24 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
         const auto x_int4 = static_cast<const int4*>(x) + token_idx * hidden_bf16_int4;
         const auto rdma_x_src_idx = reinterpret_cast<int*>(static_cast<uint8_t*>(rdma_x) + token_idx * num_bytes_per_msg);
+        const auto rdma_x_sf_scale = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int));
         const auto rdma_x_vec = reinterpret_cast<vec_t*>(reinterpret_cast<uint8_t*>(rdma_x_src_idx) + sizeof(int4));
-        const auto rdma_x_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
+        const auto rdma_x_scales = reinterpret_cast<rdma_x_scale_t*>(reinterpret_cast<uint8_t*>(rdma_x_vec) + hidden_bytes);
 
         // Overlap top-k index read and source token index writes
         auto dst_expert_idx = warp_id < num_topk ? static_cast<int>(__ldg(topk_idx + token_idx * num_topk + warp_id)) : -1;
         thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;
+        float SFScaleVal = 1.0f;
+        if constexpr (kUseNVFP4) {
+            // Get scaling value: if x_sf_scale is nullptr, use 1.0f; otherwise, read value at token_idx
+            if (x_sf_scale != nullptr) {
+                SFScaleVal = *(static_cast<const float*>(x_sf_scale) + token_idx);
+            }
+            // Only thread 0 writes scaling value to rdma_x_sf_scale
+            if (thread_id == 0) {
+                *rdma_x_sf_scale = SFScaleVal;
+            }
+        }
 
         // FP8 cast
         EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
@@ -141,6 +306,20 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                     fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3);
                 }
                 rdma_x_vec[i] = int2_value;
+            } else if constexpr (kUseNVFP4) {
+                // Convert to NVFP4
+                uint8_t sf_val;
+                PackedVec vec = *reinterpret_cast<PackedVec*>(&int4_value);
+                uint32_t result = cvt_warp_fp16_to_fp4<kUseUE8M0ForNVFP4SF>(vec, SFScaleVal, &sf_val);
+
+                // Write scale to send buffer
+                if (lane_id % 2 == 0) {
+                    EP_DEVICE_ASSERT((i * kNumElemsPerRead) % 16 == 0);
+                    int rdma_x_scale_idx = i * kNumElemsPerRead / 16;
+                    rdma_x_scales[rdma_x_scale_idx] = sf_val;
+                }
+                // Cast into send buffer
+                rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&result);
             } else {
                 // Reinterpret-cast is for C++14 compatibility
                 rdma_x_vec[i] = *reinterpret_cast<vec_t*>(&int4_value);
@@ -262,6 +441,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
     const auto recv_x_int4 = static_cast<int4*>(packed_recv_x) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * hidden_int4;
     const auto recv_src_info = packed_recv_src_info + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
+    const auto recv_sf_scale = static_cast<float*>(packed_recv_x_sf_scale) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank;
     const auto recv_range = packed_recv_layout_range + local_expert_idx * num_ranks;
     const auto num_aligned_scales = align<int>(num_scales, sizeof(float) / sizeof(scale_t));
     const auto recv_x_scales = static_cast<scale_t*>(packed_recv_x_scales) + local_expert_idx * num_ranks * num_max_dispatch_tokens_per_rank * num_aligned_scales;
@@ -294,12 +474,17 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         recv_token_begin_idx = shared_recv_token_begin_idx[warp_group_id];
 
         // Copy tokens
-        EP_DEVICE_ASSERT(num_scales <= 64);
         for (int i = sub_warp_id; i < num_recv_tokens; i += num_warps_per_group) {
             // Copy source info
             const auto src_src_idx = reinterpret_cast<const int*>(rdma_recv_x_uint8 + i * num_bytes_per_msg);
             if (lane_id == 0)
                 recv_src_info[recv_token_begin_idx + i] = ld_nc_global(src_src_idx);
+            if constexpr (kUseNVFP4) {
+                const auto src_sf_scale_for_nvfp4 = reinterpret_cast<const float*>(rdma_recv_x_uint8 + i * num_bytes_per_msg + sizeof(int));
+                if (lane_id == 0)
+                    recv_sf_scale[recv_token_begin_idx + i] = ld_nc_global(src_sf_scale_for_nvfp4);
+            }
+
             __syncwarp();
 
             // Copy data
@@ -310,6 +495,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
 
             // Copy scales
             if constexpr (kUseFP8) {
+                EP_DEVICE_ASSERT(num_scales <= 64);
                 // Equivalent CuTe layout:
                 //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
                 const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
@@ -329,22 +515,40 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
                     auto scale = extract_required_scale_format<kUseUE8M0>(ld_nc_global(src_scales + lane_id + 32));
                     recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
                 }
+            } else if constexpr (kUseNVFP4) {
+                // Equivalent CuTe layout:
+                //   (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
+                const auto src_scales = reinterpret_cast<uint8_t*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
+                const auto num_elems_per_pack = static_cast<int>(sizeof(packed_t) / sizeof(scale_t));
+                const auto token_idx = recv_token_begin_idx + i;
+                const auto token_stride = num_elems_per_pack;
+                const auto pack_stride = num_ranks * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
+                #pragma unroll
+                for (int j = lane_id; j < num_scales; j += 32) {
+                    const auto pack_idx = j / num_elems_per_pack;
+                    const auto elem_idx = j % num_elems_per_pack;
+                    auto scale = ld_nc_global(src_scales + j);
+                    recv_x_scales[token_idx * token_stride + pack_idx * pack_stride + elem_idx] = scale;
+                }
             }
         }
     }
 }
 
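The NVFP4 branch above stores scale j of each token with the same packed, column-major layout the FP8 path uses. A small index-arithmetic sketch of those strides, as an illustrative helper that is not part of the patch, assuming four FP8 scale factors per int32 pack:

    // Illustrative only: flat element offset into recv_x_scales for scale j of a token,
    // mirroring token_stride / pack_stride in the NVFP4 copy loop above.
    #include <cstdint>

    inline int64_t scale_offset(int token_idx, int j, int num_ranks,
                                int num_max_dispatch_tokens_per_rank,
                                int num_elems_per_pack = 4) {   // sizeof(int32_t) / sizeof(uint8_t)
        const int pack_idx = j / num_elems_per_pack;
        const int elem_idx = j % num_elems_per_pack;
        const int64_t token_stride = num_elems_per_pack;
        const int64_t pack_stride = int64_t(num_ranks) * num_max_dispatch_tokens_per_rank * num_elems_per_pack;
        return token_idx * token_stride + pack_idx * pack_stride + elem_idx;
    }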
 void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
+              void* packed_recv_x_sf_scale,
               int* packed_recv_src_info, int64_t* packed_recv_layout_range, int* packed_recv_count,
               int* cumulative_local_expert_recv_stats,
               int64_t* dispatch_wait_recv_cost_stats,
+              const float* x_sf_scale,
               void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
               const void* x, const int64_t* topk_idx,
               int* next_clean, int num_next_clean_int,
               int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
               int num_topk, int num_experts, int rank, int num_ranks,
               bool use_fp8, bool round_scale, bool use_ue8m0,
+              bool use_nvfp4, bool use_ue8m0_for_nvfp4_sf,
               void* workspace, int num_device_sms,
               cudaStream_t stream, int phases) {
     constexpr int kNumMaxTopK = 9;
@@ -367,17 +571,22 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
         EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");
 
 #define DISPATCH_LAUNCH_CASE(hidden) { \
-auto dispatch_func = dispatch