24 changes: 20 additions & 4 deletions csrc/deep_ep.cpp
@@ -1092,8 +1092,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook) {
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook,
bool use_per_tensor_quantization,
const std::optional<torch::Tensor>& static_scale) {
#ifndef DISABLE_NVSHMEM
EP_HOST_ASSERT(low_latency_mode);

@@ -1148,7 +1150,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
void* packed_recv_x_scales_ptr = nullptr;
EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4");

if (use_fp8) {
if (use_fp8 and not use_per_tensor_quantization) {
// TODO: support unaligned cases
EP_HOST_ASSERT(hidden % 512 == 0);
if (not use_ue8m0) {
@@ -1163,6 +1165,19 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr();
}

// Check the static quantization arguments
if (static_scale.has_value()) {
EP_HOST_ASSERT(use_fp8 and "Static scale requires FP8 quantization");
auto scale_tensor = static_scale.value();
EP_HOST_ASSERT(scale_tensor.is_contiguous());
EP_HOST_ASSERT(scale_tensor.scalar_type() == torch::kFloat32);
if (use_per_tensor_quantization) {
EP_HOST_ASSERT(scale_tensor.numel() == 1);
} else {
EP_HOST_ASSERT(scale_tensor.numel() == hidden / 128);
}
}

// Kernel launch
auto next_clean_meta = next_buffer.clean_meta();
auto launcher = [=](int phases) {
@@ -1177,7 +1192,8 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
next_clean_meta.first, next_clean_meta.second,
num_tokens, hidden, num_max_dispatch_tokens_per_rank,
num_topk, num_experts, rank, num_ranks,
use_fp8, round_scale, use_ue8m0,
use_fp8, round_scale, use_ue8m0, use_per_tensor_quantization,
static_scale.has_value() ? static_scale->data_ptr<float>() : nullptr, // pass the static quantization scale (nullptr if not provided)
workspace, num_device_sms,
launch_stream, phases);
};
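For illustration, a minimal sketch of `static_scale` tensors that would pass the host-side checks added above; the hidden size (7168) and the scale value are placeholders, not taken from this patch:

```python
import torch

hidden = 7168  # placeholder hidden size, not taken from the patch

# Per-tensor static quantization: a single contiguous float32 element.
per_tensor_scale = torch.tensor([0.0125], dtype=torch.float32, device='cuda')
assert per_tensor_scale.is_contiguous() and per_tensor_scale.numel() == 1

# Without `use_per_tensor_quantization`, the check instead expects one
# float32 scale per 128-channel group.
per_group_scales = torch.full((hidden // 128,), 0.0125,
                              dtype=torch.float32, device='cuda')
assert per_group_scales.numel() == hidden // 128
```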
4 changes: 3 additions & 1 deletion csrc/deep_ep.hpp
@@ -149,7 +149,9 @@ struct Buffer {
const std::optional<torch::Tensor>& dispatch_wait_recv_cost_stats,
int num_max_dispatch_tokens_per_rank, int num_experts,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool async, bool return_recv_hook);
bool async, bool return_recv_hook,
bool use_per_tensor_quantization,
const std::optional<torch::Tensor>& static_scale);

std::tuple<torch::Tensor, std::optional<EventHandle>, std::optional<std::function<void()>>>
low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights,
14 changes: 14 additions & 0 deletions csrc/kernels/api.cuh
@@ -153,6 +153,20 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);

void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
int* packed_recv_count,
int* cumulative_local_expert_recv_stats,
int64_t* dispatch_wait_recv_cost_stats,
void* rdma_recv_x, int* rdma_recv_count, void* rdma_x,
const void* x, const int64_t* topk_idx,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_per_tensor_quantization, const float* static_scale,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases);

void combine(void* combined_x,
void* rdma_recv_x, int* rdma_recv_flag, void* rdma_send_x,
const void* x, const int64_t* topk_idx, const float* topk_weights,
71 changes: 42 additions & 29 deletions csrc/kernels/internode_ll.cu
@@ -36,7 +36,7 @@ void clean_low_latency_buffer(int* clean_0, int num_clean_int_0,
clean_0, num_clean_int_0, clean_1, num_clean_int_1);
}

template <bool kUseFP8, bool kUseUE8M0, int kHidden>
template <bool kUseFP8, bool kUseUE8M0, bool kUsePerTensorStaticQuantization, int kHidden>
__global__ __launch_bounds__(1024, 1) void
dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* packed_recv_src_info, int64_t* packed_recv_layout_range,
Expand All @@ -50,7 +50,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int num_tokens, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
int num_warp_groups, int num_warps_per_group,
bool round_scale, int phases) {
bool round_scale, const float* static_scale, int phases) {
const auto sm_id = static_cast<int>(blockIdx.x);
const auto thread_id = static_cast<int>(threadIdx.x);
const auto warp_id = thread_id / 32, lane_id = get_lane_id();
@@ -68,14 +68,15 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,

// FP8 staffs
constexpr int kNumPerChannels = 128;
const int num_scales = kHidden / kNumPerChannels;
const int num_scales = kUsePerTensorStaticQuantization ? 1 : kHidden / kNumPerChannels;
const size_t hidden_bytes = kHidden * (kUseFP8 ? sizeof(__nv_fp8_storage_t) : sizeof(nv_bfloat16));
const size_t hidden_int4 = hidden_bytes / sizeof(int4);

// Message package: hidden data, FP8 scales, index at source
// NOTES: currently we have 3 reserved int fields for future use
using vec_t = std::conditional_t<kUseFP8, int2, int4>;
const size_t num_bytes_per_msg = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
const size_t base_bytes = sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float)) : (kHidden * sizeof(nv_bfloat16)));
const size_t num_bytes_per_msg = (base_bytes + sizeof(int4) - 1) / sizeof(int4) * sizeof(int4); // round up to a 16-byte (int4) boundary
const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);

@@ -108,31 +109,39 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
thread_id == 0 ? (*rdma_x_src_idx = token_idx) : 0;

// FP8 cast
EP_STATIC_ASSERT(hidden_bf16_int4 % 32 == 0, "Must use the full warp to reduce");
#pragma unroll
for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) {
// Read
auto int4_value = __ldg(x_int4 + i);

if constexpr (kUseFP8) {
// Calculate local amax
float scale, scale_inv;
auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value);
float fp32_values[kNumElemsPerRead];
float amax = kFP8Margin, scale, scale_inv;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j]));
}

// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = warp_reduce_max<16>(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
if constexpr (kUsePerTensorStaticQuantization) {
EP_DEVICE_ASSERT(static_scale != nullptr);
scale_inv = static_scale[0];
scale = 1.0f / scale_inv;
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
}

// Cast into send buffer
} else {
float amax = kFP8Margin;
#pragma unroll
for (int j = 0; j < kNumElemsPerRead; ++ j) {
fp32_values[j] = static_cast<float>(bf16_values[j]);
amax = fmaxf(amax, fabsf(fp32_values[j]));
}
// Reduce amax and scale
EP_STATIC_ASSERT(kNumElemsPerRead * 32 / kNumPerChannels == 2, "Invalid vectorization");
amax = warp_reduce_max<16>(amax);
calculate_fp8_scales(amax, scale, scale_inv, round_scale);
if (lane_id == 0 or lane_id == 16)
rdma_x_scales[i * kNumElemsPerRead / 128] = scale_inv;
}

// Cast into the send buffer using the selected scale (static or dynamic)
vec_t int2_value;
auto fp8x2_values = reinterpret_cast<__nv_fp8x2_storage_t*>(&int2_value);
#pragma unroll
@@ -309,7 +318,7 @@ dispatch(void* packed_recv_x, void* packed_recv_x_scales,
UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);

// Copy scales
if constexpr (kUseFP8) {
if constexpr (kUseFP8 and not kUsePerTensorStaticQuantization) {
// Equivalent CuTe layout:
// (num_tokens, (num_packed, num_elems_per_pack)):(num_elems_per_pack, (num_tokens * num_elems_per_pack, 1))
const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
@@ -344,7 +353,7 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
int* next_clean, int num_next_clean_int,
int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
int num_topk, int num_experts, int rank, int num_ranks,
bool use_fp8, bool round_scale, bool use_ue8m0,
bool use_fp8, bool round_scale, bool use_ue8m0, bool use_per_tensor_quantization, const float* static_scale,
void* workspace, int num_device_sms,
cudaStream_t stream, int phases) {
constexpr int kNumMaxTopK = 9;
@@ -367,11 +376,15 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
EP_HOST_ASSERT(round_scale and "UE8M0 SF requires `round_scale=True`");

#define DISPATCH_LAUNCH_CASE(hidden) { \
auto dispatch_func = dispatch<false, false, hidden>; \
if (use_fp8 and not use_ue8m0) \
dispatch_func = dispatch<true, false, hidden>; \
if (use_fp8 and use_ue8m0) \
dispatch_func = dispatch<true, true, hidden>; \
auto dispatch_func = dispatch<false, false, false, hidden>; \
if (use_fp8 and not use_ue8m0 and not use_per_tensor_quantization) \
dispatch_func = dispatch<true, false, false, hidden>; \
if (use_fp8 and use_ue8m0 and not use_per_tensor_quantization) \
dispatch_func = dispatch<true, true, false, hidden>; \
if (use_fp8 and not use_ue8m0 and use_per_tensor_quantization) \
dispatch_func = dispatch<true, false, true, hidden>; \
LAUNCH_KERNEL(&cfg, dispatch_func, \
packed_recv_x, packed_recv_x_scales, \
packed_recv_src_info, packed_recv_layout_range, \
Expand All @@ -385,7 +398,7 @@ LAUNCH_KERNEL(&cfg, dispatch_func, \
num_tokens, num_max_dispatch_tokens_per_rank, \
num_topk, num_experts, rank, num_ranks, \
num_warp_groups, num_warps_per_group, \
round_scale, phases); } break
round_scale, static_scale, phases); } break

SETUP_LAUNCH_CONFIG(num_sms, num_warps * 32, stream);
SWITCH_HIDDEN(DISPATCH_LAUNCH_CASE);
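A quick note on the message sizing above: with per-tensor quantization only a single scale slot is reserved, so for an assumed hidden size of 7168 the FP8 payload is 16 + 7168 + 4 = 7188 bytes, which the new rounding pads to 7200 bytes, the next multiple of 16. The sketch below mirrors the per-tensor FP8 math in plain PyTorch as an illustration only; the e4m3 saturation bound of 448 and the tensor shapes are assumptions, not code taken from the kernel:

```python
import torch  # requires a PyTorch build with float8 dtypes

def quantize_per_tensor_fp8(x: torch.Tensor, scale_inv: float) -> torch.Tensor:
    # `scale_inv` plays the role of static_scale[0], i.e. the dequantization scale;
    # the kernel multiplies by its reciprocal before casting to FP8 (e4m3).
    scale = 1.0 / scale_inv
    x_scaled = x.float() * scale
    # Saturate to the assumed e4m3 finite range, mimicking a saturating cast.
    return x_scaled.clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

def dequantize_per_tensor_fp8(q: torch.Tensor, scale_inv: float) -> torch.Tensor:
    # The receiver recovers the values with the same static scale.
    return q.float() * scale_inv

x = torch.randn(4, 7168, dtype=torch.bfloat16)
q = quantize_per_tensor_fp8(x, scale_inv=0.0125)
x_hat = dequantize_per_tensor_fp8(q, scale_inv=0.0125)
```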
9 changes: 7 additions & 2 deletions deep_ep/buffer.py
@@ -529,7 +529,9 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None,
dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None,
use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False,
async_finish: bool = False, return_recv_hook: bool = False) -> \
async_finish: bool = False, return_recv_hook: bool = False,
use_per_tensor_quantization: bool = False,
static_scale: Optional[torch.Tensor] = None) -> \
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]:
"""
A low-latency implementation for dispatching with IBGDA.
@@ -558,6 +560,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
return_recv_hook: return a receiving hook if set. If set, the kernel will only issue the RDMA requests,
but **without actually receiving the data**. You must call the returned hook to make sure the data has arrived.
If you do not set this flag, the kernel will ensure the data's arrival.
use_per_tensor_quantization: whether to use per-tensor static quantization for the FP8 cast (instead of per-128-channel dynamic scaling).
static_scale: an optional contiguous `torch.float32` tensor with the static quantization scale(s): a single element when `use_per_tensor_quantization` is set, `hidden // 128` elements otherwise. Requires `use_fp8`.

Returns:
recv_x: a tensor or tuple with received tokens for each expert.
@@ -584,7 +588,8 @@ def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor,
dispatch_wait_recv_cost_stats,
num_max_dispatch_tokens_per_rank, num_experts,
use_fp8, round_scale, use_ue8m0,
async_finish, return_recv_hook)
async_finish, return_recv_hook, use_per_tensor_quantization,
static_scale)
handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts)
tensors_to_record = (x, topk_idx,
packed_recv_x, packed_recv_x_scales, packed_recv_count,
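Finally, a hedged usage sketch of the extended Python API. The buffer construction, token counts, and scale value are placeholders; keyword names follow the diff above, and arguments not shown in these hunks are assumed to keep their existing order and meaning:

```python
import torch

# `buffer` is an already-initialized deep_ep.Buffer in low-latency mode (construction omitted);
# `x` holds bf16 activations of shape [num_tokens, hidden] and `topk_idx` the top-k expert indices.
static_scale = torch.tensor([0.0125], dtype=torch.float32, device='cuda')

recv_x, recv_count, handle, event, hook = buffer.low_latency_dispatch(
    x, topk_idx,
    num_max_dispatch_tokens_per_rank=128,   # placeholder
    num_experts=256,                        # placeholder
    use_fp8=True,                           # the static scale requires FP8 quantization
    use_per_tensor_quantization=True,       # new flag introduced by this change
    static_scale=static_scale,              # contiguous float32, numel() == 1
    return_recv_hook=False)
```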