Skip to content

Commit 076d3ff

Browse files
committed
Merge branch 'develop' into sot/add-more-virtual-destructor-for-virtual-class
2 parents 00b6601 + 6d1a7c0 commit 076d3ff

File tree

17 files changed

+233
-81
lines changed

17 files changed

+233
-81
lines changed

cmake/external/xpu.cmake

+7-7
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ if(NOT DEFINED XPU_XHPC_BASE_DATE)
3434
endif()
3535
set(XPU_XCCL_BASE_VERSION "3.0.2.5") # For XRE5
3636
if(NOT DEFINED XPU_XFT_BASE_VERSION)
37-
set(XPU_XFT_BASE_VERSION "20230602")
37+
set(XPU_XFT_BASE_VERSION "20250402/xpu3")
3838
endif()
3939

4040
if(NOT DEFINED XPU_XRE_BASE_VERSION)
4141
if(WITH_XPU_XRE5)
42-
set(XPU_XRE_BASE_VERSION "5.0.21.18")
42+
set(XPU_XRE_BASE_VERSION "5.0.21.19")
4343
else()
4444
set(XPU_XRE_BASE_VERSION "4.32.0.1")
4545
endif()
@@ -61,7 +61,7 @@ set(XPU_XCCL_BASE_URL
6161

6262
if(NOT XPU_XFT_BASE_URL)
6363
set(XPU_XFT_BASE_URL
64-
"https://klx-sdk-release-public.su.bcebos.com/xft/dev/${XPU_XFT_BASE_VERSION}"
64+
"https://klx-sdk-release-public.su.bcebos.com/xft_internal/dev/${XPU_XFT_BASE_VERSION}"
6565
)
6666
endif()
6767

@@ -112,7 +112,7 @@ else()
112112
set(XPU_XHPC_DIR_NAME "xhpc-ubuntu1604_x86_64")
113113
endif()
114114
set(XPU_XCCL_DIR_NAME "xccl_Linux_x86_64")
115-
set(XPU_XFT_DIR_NAME "xft_ubuntu1604_x86_64")
115+
set(XPU_XFT_DIR_NAME "xft_internal_ubuntu2004")
116116
endif()
117117

118118
set(XPU_XRE_URL
@@ -187,9 +187,9 @@ if(DEFINED ENV{XPU_LIB_ROOT})
187187
endif()
188188

189189
# XCCL
190-
if(DEFINED ENV{XCCL_DIR_NAME})
191-
set(XPU_XCCL_URL "${XPU_LIB_ROOT}/$ENV{XCCL_DIR_NAME}")
192-
set(XCCL_DIR_NAME "$ENV{XCCL_DIR_NAME}")
190+
if(DEFINED ENV{XPU_XCCL_DIR_NAME})
191+
set(XPU_XCCL_URL "${XPU_LIB_ROOT}/$ENV{XPU_XCCL_DIR_NAME}")
192+
set(XPU_XCCL_DIR_NAME "$ENV{XPU_XCCL_DIR_NAME}")
193193
endif()
194194

195195
# XHPC

paddle/fluid/distributed/collective/deep_ep/deep_ep.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,8 @@ void Buffer::clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank,
16251625
#endif
16261626
}
16271627

1628+
// Enqueues the internode low-latency barrier on the stream owned by
// `calc_ctx`. The call only forwards the stream; no host-side synchronization
// is performed here.
void Buffer::barrier_all() {
  const auto stream = calc_ctx->stream();
  internode_ll::barrier_all(stream);
}
1629+
16281630
#ifdef PADDLE_WITH_NVSHMEM
16291631
std::tuple<deep_ep::detail::Tensor,
16301632
std::optional<deep_ep::detail::Tensor>,

paddle/fluid/distributed/collective/deep_ep/deep_ep.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ struct Buffer {
251251
void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank,
252252
int hidden,
253253
int num_experts);
254+
void barrier_all();
254255

255256
#ifdef PADDLE_WITH_NVSHMEM
256257
std::tuple<deep_ep::detail::Tensor,

paddle/fluid/distributed/collective/deep_ep/kernels/api.cuh

+1
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ void combine(cudaDataType_t type,
288288
// Internode low-latency kernels
289289
namespace internode_ll {
290290

291+
void barrier_all(cudaStream_t stream);
291292
void clean_low_latency_buffer(int* clean_0,
292293
int num_clean_int_0,
293294
int* clean_1,

paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh

+63
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ uint16_t HtoBE16(uint16_t x) {
8282

8383
typedef struct mlx5_wqe_ctrl_seg __attribute__((__aligned__(8))) ibgda_ctrl_seg_t;
8484

85+
// Overlay for the atomic segment of a 32-bit masked fetch-and-add WQE.
// It is written through a pointer obtained by reinterpret_cast from a
// `mlx5_wqe_atomic_seg` and is then stored as one 16-byte vector, so its
// packed size must stay exactly 16 bytes.
typedef struct {
  uint32_t add_data;        // addend (stored big-endian by the writer)
  uint32_t field_boundary;  // written as 0 by the writer in this file
  uint64_t reserved;        // padding up to 16 bytes
} __attribute__((__packed__)) ibgda_atomic_32_masked_fa_seg_t;

// Guard the layout at compile time: any field change that alters the size
// would otherwise silently corrupt the WQE built from this overlay.
static_assert(sizeof(ibgda_atomic_32_masked_fa_seg_t) == 16,
              "ibgda_atomic_32_masked_fa_seg_t must be exactly 16 bytes");
90+
8591
__device__ static __forceinline__
8692
nvshmemi_ibgda_device_state_t* ibgda_get_state() {
8793
return &nvshmemi_ibgda_device_state_d;
@@ -439,4 +445,61 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
439445
__syncwarp();
440446
}
441447

448+
// Builds one 32-bit masked fetch-and-add work-queue entry and stores it into
// the WQE slot at `out_wqes[0]`.
//
// The entry consists of four consecutive 16-byte segments (control, remote
// address, atomic, data); each is staged in a local struct and then written
// with a single 16-byte `st_na_relaxed` store.
//
// Params:
//   qp        queue pair; only `qp->qpn` is read here.
//   value     addend, stored big-endian in the atomic segment.
//   laddr     local buffer address placed in the data segment.
//   lkey      key for `laddr`, stored as given (already `__be32`).
//   raddr     remote address placed in the remote-address segment.
//   rkey      key for `raddr`, stored as given (already `__be32`).
//   wqe_idx   WQE index, encoded into the control segment shifted left by 8.
//   out_wqes  `out_wqes[0]` is the base address of the WQE slot to fill.
__device__ static __forceinline__ void ibgda_write_amo_add_wqe(
    nvshmemi_ibgda_device_qp_t *qp, const int &value,
    uint64_t laddr, __be32 lkey, uint64_t raddr, __be32 rkey,
    uint16_t wqe_idx, void **out_wqes) {
  ibgda_ctrl_seg_t ctrl_seg = {0};
  struct mlx5_wqe_raddr_seg raddr_seg;
  // Zero-initialised: only the first 8 bytes are assigned below, but the whole
  // 16-byte segment is copied into the WQE. Without this the trailing
  // `reserved` bytes would carry indeterminate register contents.
  struct mlx5_wqe_atomic_seg atomic_seg_1 = {};
  struct mlx5_wqe_data_seg data_seg;

  // Segment destinations: laid out back to back starting at out_wqes[0].
  auto ctrl_seg_ptr = reinterpret_cast<ibgda_ctrl_seg_t*>(out_wqes[0]);
  auto raddr_seg_ptr = reinterpret_cast<mlx5_wqe_raddr_seg*>(reinterpret_cast<uintptr_t>(ctrl_seg_ptr) + sizeof(*ctrl_seg_ptr));
  auto atomic_seg_ptr = reinterpret_cast<mlx5_wqe_atomic_seg*>(reinterpret_cast<uintptr_t>(raddr_seg_ptr) + sizeof(*raddr_seg_ptr));
  auto data_seg_ptr = reinterpret_cast<mlx5_wqe_data_seg*>(reinterpret_cast<uintptr_t>(atomic_seg_ptr) + sizeof(*atomic_seg_ptr));

  raddr_seg.raddr = HtoBE64(raddr);
  raddr_seg.rkey = rkey;
  raddr_seg.reserved = 0;

  // NOTES: `0x08000000` means `IBGDA_4_BYTE_EXT_AMO_OPMOD`
  ctrl_seg.opmod_idx_opcode = HtoBE32(MLX5_OPCODE_ATOMIC_MASKED_FA | (wqe_idx << 8) | 0x08000000);
  auto atomic_32_masked_fa_seg = reinterpret_cast<ibgda_atomic_32_masked_fa_seg_t*>(&atomic_seg_1);
  atomic_32_masked_fa_seg->add_data = HtoBE32(value);
  atomic_32_masked_fa_seg->field_boundary = 0;

  // The low byte `4` matches the four 16-byte segments written below
  // (presumably the segment count expected by the NIC -- confirm against the
  // mlx5 WQE format).
  ctrl_seg.qpn_ds = HtoBE32((qp->qpn << 8) | 4);
  ctrl_seg.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE;

  data_seg.byte_count = HtoBE32(sizeof(int));
  data_seg.lkey = lkey;
  data_seg.addr = HtoBE64(laddr);

  // Every segment must be exactly one int4 wide for the vector stores below.
  EP_STATIC_ASSERT(sizeof(*ctrl_seg_ptr) == sizeof(int4), "Invalid vectorization");
  EP_STATIC_ASSERT(sizeof(*raddr_seg_ptr) == sizeof(int4), "Invalid vectorization");
  EP_STATIC_ASSERT(sizeof(*atomic_seg_ptr) == sizeof(int4), "Invalid vectorization");
  EP_STATIC_ASSERT(sizeof(*data_seg_ptr) == sizeof(int4), "Invalid vectorization");
  st_na_relaxed(reinterpret_cast<int4*>(ctrl_seg_ptr), *reinterpret_cast<int4*>(&ctrl_seg));
  st_na_relaxed(reinterpret_cast<int4*>(raddr_seg_ptr), *reinterpret_cast<int4*>(&raddr_seg));
  st_na_relaxed(reinterpret_cast<int4*>(atomic_seg_ptr), *reinterpret_cast<int4*>(&atomic_seg_1));
  st_na_relaxed(reinterpret_cast<int4*>(data_seg_ptr), *reinterpret_cast<int4*>(&data_seg));
}
488+
489+
// Issues a non-fetching atomic add of `value` to the remote address `rptr`
// on PE `pe`, using the RC queue pair selected by `qp_id`.
//
// Steps: look up the queue pair, translate `rptr` into a remote address/key
// pair, reserve one WQE slot, fill it via ibgda_write_amo_add_wqe (with the
// queue pair's `ibuf` as the local buffer), and submit the request.
__device__ __forceinline__ void nvshmemi_ibgda_amo_nonfetch_add(void *rptr, const int& value, int pe, int qp_id) {
  nvshmemi_ibgda_device_qp_t *qp = ibgda_get_rc(pe, qp_id);

  // Resolve the remote address and its key for the target PE.
  uint64_t remote_addr;
  __be32 remote_key;
  ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), pe, &remote_addr, &remote_key);

  // Reserve a single slot and locate its storage.
  uint64_t slot_idx = ibgda_reserve_wqe_slots(qp, 1);
  void *slot_ptr = ibgda_get_wqe_ptr(qp, slot_idx);

  // The writer takes an array of slot pointers; pass the single slot by
  // address.
  const uint64_t local_addr = reinterpret_cast<uint64_t>(qp->ibuf.buf);
  ibgda_write_amo_add_wqe(qp, value, local_addr, qp->ibuf.lkey,
                          remote_addr, remote_key, slot_idx, &slot_ptr);

  ibgda_submit_requests<true>(qp, slot_idx, 1);
}
504+
442505
} // namespace deep_ep

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu

+33-41
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ namespace deep_ep {
3434

3535
namespace internode_ll {
3636

37+
// Kernel whose only work is a call to nvshmemx_barrier_all_block(). It takes
// no arguments and touches no memory itself.
__global__ void barrier_all() { nvshmemx_barrier_all_block(); }
38+
39+
// Host-side launcher for the barrier_all kernel.
//
// Enqueues the kernel on `stream` with a launch configuration of (1, 1) as
// passed to SETUP_LAUNCH_CONFIG. This function does not wait for the kernel
// to finish.
void barrier_all(cudaStream_t stream) {
  constexpr int kThreadsPerLaunch = 1;

  // SETUP_LAUNCH_CONFIG declares `cfg`, which LAUNCH_KERNEL consumes.
  SETUP_LAUNCH_CONFIG(1, kThreadsPerLaunch, stream);
  LAUNCH_KERNEL(&cfg, barrier_all);
}
45+
3746
template <int kNumThreads>
3847
__launch_bounds__(kNumThreads, 1) __global__ void clean_low_latency_buffer(
3948
int* clean_0, int num_clean_int_0, int* clean_1, int num_clean_int_1) {
@@ -112,7 +121,6 @@ __global__ __launch_bounds__(
112121

113122
// Message package: hidden data, FP8 scales, index at source
114123
// NOTES: currently we have 3 reserved int fields for future use
115-
116124
using vec_t = typename std::conditional<kUseFP8, int2, int4>::type;
117125
const size_t num_bytes_per_msg =
118126
sizeof(int4) + (kUseFP8 ? (kHidden + num_scales * sizeof(float))
@@ -305,13 +313,11 @@ __global__ __launch_bounds__(
305313
responsible_expert_idx) != FINISHED_SUM_TAG * 2) {
306314
}
307315
if (dst_rank != rank) {
308-
nvshmemi_ibgda_rma_p(
316+
nvshmemi_ibgda_amo_nonfetch_add(
309317
rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
310318
-num_tokens_sent - 1,
311319
dst_rank,
312-
dst_expert_local_idx,
313-
0);
314-
nvshmemi_ibgda_prepare_recvs(dst_rank, dst_expert_local_idx);
320+
dst_expert_local_idx);
315321
} else {
316322
st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
317323
-num_tokens_sent - 1);
@@ -366,16 +372,9 @@ LOW_LATENCY_DISPATCH_RECV:
366372
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
367373
"Requires more than one warp per group");
368374
if (sub_warp_id == 1 && lane_id == 0) {
369-
if (src_rank != rank) {
370-
nvshmemi_ibgda_poll_recv(src_rank, local_expert_idx);
371-
num_recv_tokens = ld_acquire_sys_global(
372-
rdma_recv_count + local_expert_idx * num_ranks + src_rank);
373-
EP_DEVICE_ASSERT(num_recv_tokens != 0);
374-
} else {
375-
while ((num_recv_tokens = ld_acquire_global(
376-
rdma_recv_count + local_expert_idx * num_ranks +
377-
src_rank)) == 0) {
378-
}
375+
while ((num_recv_tokens = ld_acquire_global(
376+
rdma_recv_count + local_expert_idx * num_ranks + src_rank)) ==
377+
0) {
379378
}
380379
num_recv_tokens = -num_recv_tokens - 1;
381380
recv_token_begin_idx =
@@ -539,7 +538,8 @@ __global__ __launch_bounds__(
539538
int num_experts,
540539
int rank,
541540
int num_ranks,
542-
int phases) {
541+
int phases,
542+
bool zero_copy) {
543543
const auto sm_id = static_cast<int>(blockIdx.x);
544544
const auto num_sms = static_cast<int>(gridDim.x);
545545
const auto thread_id = static_cast<int>(threadIdx.x);
@@ -580,7 +580,9 @@ __global__ __launch_bounds__(
580580
const auto local_expert_idx = responsible_expert_idx % num_local_experts;
581581
const auto global_expert_idx = rank * num_local_experts + local_expert_idx;
582582
const auto layout =
583-
__ldg(layout_range + local_expert_idx * num_ranks + dst_rank);
583+
__ldg(layout_range + local_expert_idx * num_ranks +
584+
dst_rank); // num_recv_tokens, recv_token_begin_idx
585+
584586
const auto local_x = reinterpret_cast<const int4*>(x) +
585587
local_expert_idx * num_ranks *
586588
num_max_dispatch_tokens_per_rank *
@@ -625,13 +627,14 @@ __global__ __launch_bounds__(
625627
st_na_global);
626628
} else {
627629
const auto buf_int4_ptr = reinterpret_cast<int4*>(buf_ptr);
628-
UNROLLED_WARP_COPY(7,
629-
lane_id,
630-
hidden_bf16_int4,
631-
buf_int4_ptr,
632-
x_int4,
633-
ld_nc_global,
634-
st_na_global);
630+
if (!zero_copy)
631+
UNROLLED_WARP_COPY(7,
632+
lane_id,
633+
hidden_bf16_int4,
634+
buf_int4_ptr,
635+
x_int4,
636+
ld_nc_global,
637+
st_na_global);
635638
nvshmemi_ibgda_put_nbi_warp(dst_ptr,
636639
buf_ptr,
637640
hidden * sizeof(nv_bfloat16),
@@ -651,11 +654,8 @@ __global__ __launch_bounds__(
651654
while (ld_acquire_global(atomic_clean_flag) == 0) {
652655
}
653656
if (dst_rank != rank) {
654-
nvshmemi_ibgda_rma_p(rdma_recv_flag + global_expert_idx,
655-
1,
656-
dst_rank,
657-
local_expert_idx,
658-
0);
657+
nvshmemi_ibgda_amo_nonfetch_add(
658+
rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
659659
} else {
660660
st_na_release(rdma_recv_flag + global_expert_idx, 1);
661661
}
@@ -672,18 +672,9 @@ LOW_LATENCY_COMBINE_RECV:
672672
if (responsible_expert_idx < num_experts) {
673673
EP_STATIC_ASSERT(kNumWarpsPerGroup > 1,
674674
"Invalid number of warps per group");
675-
if (sub_warp_id == 0 && lane_id == 0) {
676-
// TODO(Xreki): refactor QP indices
677-
auto src_rank = responsible_expert_idx / num_local_experts;
678-
auto src_expert_idx = responsible_expert_idx % num_local_experts;
679-
if (src_rank != rank) {
680-
nvshmemi_ibgda_poll_recv(src_rank, src_expert_idx);
681-
} else {
682-
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) ==
683-
0) {
684-
}
675+
if (sub_warp_id == 0 && lane_id == 0)
676+
while (ld_acquire_global(rdma_recv_flag + responsible_expert_idx) == 0) {
685677
}
686-
}
687678
}
688679
cg::this_grid().sync();
689680

@@ -796,7 +787,8 @@ void combine(void* combined_x,
796787
num_experts, \
797788
rank, \
798789
num_ranks, \
799-
phases); \
790+
phases, \
791+
false); \
800792
} \
801793
break
802794

paddle/fluid/framework/ir/xpu/weight_only_linear_xpu_pass.cc

+6-6
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ PermuteINT8WeightOnlyPattern::PermuteINT8WeightOnlyPattern(
4444
PDPattern* pattern, const std::string& name_scope)
4545
: PatternBase(pattern, name_scope, name_scope) {
4646
auto* input = pattern->NewNode(input_repr())
47-
->assert_is_op_input("weight_only_linear_xpu", "x")
47+
->assert_is_op_input("weight_only_linear", "x")
4848
->AsInput();
4949
auto* weight = pattern->NewNode(weight_repr())
50-
->assert_is_op_input("weight_only_linear_xpu", "weight")
50+
->assert_is_op_input("weight_only_linear", "weight")
5151
->AsInput();
5252
auto* weight_scale =
5353
pattern->NewNode(weight_scale_repr())
54-
->assert_is_op_input("weight_only_linear_xpu", "weight_scale")
54+
->assert_is_op_input("weight_only_linear", "weight_scale")
5555
->AsInput();
5656
auto* out = pattern->NewNode(out_repr())
57-
->assert_is_op_output("weight_only_linear_xpu", "out")
57+
->assert_is_op_output("weight_only_linear", "out")
5858
->AsOutput();
5959
auto* weight_only_linear = pattern->NewNode(weight_only_linear_repr())
60-
->assert_is_op("weight_only_linear_xpu");
60+
->assert_is_op("weight_only_linear");
6161

6262
std::vector<PDNode*> input_vars{input, weight, weight_scale};
6363
std::vector<PDNode*> output_vars{out};
@@ -236,4 +236,4 @@ REGISTER_PASS(weight_only_linear_xpu_pass,
236236
REGISTER_PASS_CAPABILITY(weight_only_linear_xpu_pass)
237237
.AddCombination(
238238
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
239-
"weight_only_linear_xpu", 0));
239+
"weight_only_linear", 0));

paddle/fluid/framework/new_executor/instruction/instruction_util.cc

+18-3
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
107107

108108
// only gpu need update. xpu not need, because xpu memcpy op kernel is
109109
// synchronous.
110-
if (phi::is_gpu_place(place) || phi::is_custom_place(place)) {
110+
if (phi::is_gpu_place(place) || phi::is_custom_place(place) ||
111+
phi::is_xpu_place(place)) {
111112
VLOG(6) << "Parse DeviceContext for " << op_name
112113
<< ", execution stream = " << execution_stream;
113114
if (execution_stream != kDefaultStream) {
@@ -136,7 +137,7 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
136137
}
137138

138139
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
139-
defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
140+
defined(PADDLE_WITH_CUSTOM_DEVICE) || defined(PADDLE_WITH_XPU_BKCL)
140141
// NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
141142
// with use_cal_stream==false by returning a device context getting from the
142143
// global NCCLCommContext instance. Because when use_calc_stream==false, in
@@ -205,7 +206,21 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
205206
op_name.compare(paddle::dialect::AllToAllOp::name()) == 0 ||
206207
op_name.compare(
207208
paddle::dialect::CSoftmaxWithCrossEntropyOp::name()) == 0) {
208-
#ifdef PADDLE_WITH_CUSTOM_DEVICE
209+
#if defined(PADDLE_WITH_XPU_BKCL)
210+
if (phi::is_xpu_place(place) && execution_stream == kDefaultStream) {
211+
VLOG(3) << "set stream for " << op_name << "in XPU device";
212+
if (origin_dev_ctx != nullptr) {
213+
// set stream
214+
auto default_stream =
215+
static_cast<DEVICE_CONTEXT*>(origin_dev_ctx)->stream();
216+
static_cast<DEVICE_CONTEXT*>(dev_ctx)->SetStream(default_stream);
217+
// todo set allocator
218+
} else {
219+
VLOG(3) << "CUSTOM DEVICE op " << op_name << " ring_id "
220+
<< ring_id << " origin_dev_ctx is nullptr";
221+
}
222+
}
223+
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
209224
if (phi::is_custom_place(place) &&
210225
execution_stream == kDefaultStream) {
211226
VLOG(3) << "set stream for " << op_name << "in Custom device";

paddle/fluid/pybind/deep_ep_api.cc

+1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ void BindDeepEPApi(pybind11::module *m) {
8989
.def("intranode_combine", &deep_ep::Buffer::intranode_combine_api)
9090
.def("internode_dispatch", &deep_ep::Buffer::internode_dispatch_api)
9191
.def("internode_combine", &deep_ep::Buffer::internode_combine_api)
92+
.def("barrier_all", &deep_ep::Buffer::barrier_all)
9293
.def("clean_low_latency_buffer",
9394
&deep_ep::Buffer::clean_low_latency_buffer)
9495
.def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch_api)

paddle/phi/backends/xpu/xpu2_op_list.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -1220,7 +1220,7 @@ XPUOpMap& get_kl2_ops() {
12201220
phi::DataType::FLOAT32})},
12211221
{"warpctc_grad", XPUKernelSet({phi::DataType::FLOAT32})},
12221222
{"warpctc", XPUKernelSet({phi::DataType::FLOAT32})},
1223-
{"weight_only_linear_xpu",
1223+
{"weight_only_linear",
12241224
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})},
12251225
{"where_index",
12261226
XPUKernelSet({phi::DataType::INT32,

paddle/phi/backends/xpu/xpu3_op_list.cc

+2
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,8 @@ XPUOpMap& get_kl3_ops() {
16901690
phi::DataType::BOOL,
16911691
phi::DataType::FLOAT32,
16921692
phi::DataType::INT64})},
1693+
{"weight_quantize",
1694+
XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::BFLOAT16})},
16931695
{"where_grad",
16941696
XPUKernelSet({phi::DataType::INT32,
16951697
phi::DataType::INT64,

0 commit comments

Comments
 (0)