Skip to content

Commit c6d6f9f

Browse files
committed
[XPU] update xhpc to improve performance of strided_copy, update interface of prelu and rsqrt
1 parent 988669d commit c6d6f9f

File tree

8 files changed

+68
-114
lines changed

8 files changed

+68
-114
lines changed

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ set(XPU_FFT_LIB_NAME "libcufft.so")
3333
add_compile_definitions(XPUAPI_NOT_INCLUDE_DEPRECATED)
3434

3535
if(NOT DEFINED XPU_XHPC_BASE_DATE)
36-
set(XPU_XHPC_BASE_DATE "dev/20250417")
36+
set(XPU_XHPC_BASE_DATE "dev/20250520")
3737
endif()
3838
set(XPU_XCCL_BASE_VERSION "3.0.2.5") # For XRE5
3939
if(NOT DEFINED XPU_XFT_BASE_VERSION)

paddle/phi/kernels/xpu/activation_grad_kernel.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,8 +606,19 @@ struct XPURsqrtGradFunctor : public funcs::BaseActivationFunctor<T> {
606606
const DenseTensor* out,
607607
const DenseTensor* dout,
608608
DenseTensor* dx) const {
609-
int r = xpu_activation_backward<Context, T, XPUType>(
610-
dev_ctx, x, out, dout, dx, xpu::rsqrt_grad<XPUType>);
609+
dev_ctx.template Alloc<T>(dx);
610+
const XPUType* out_data = nullptr;
611+
const XPUType* dout_data = nullptr;
612+
if (out != nullptr) {
613+
out_data = reinterpret_cast<const XPUType*>(out->data<T>());
614+
}
615+
if (dout != nullptr) {
616+
dout_data = reinterpret_cast<const XPUType*>(dout->data<T>());
617+
}
618+
XPUType* dx_data = reinterpret_cast<XPUType*>(dx->data<T>());
619+
620+
int r = xpu::rsqrt_grad(
621+
dev_ctx.x_context(), out_data, dout_data, dx_data, dx->numel());
611622
PADDLE_ENFORCE_XDNN_SUCCESS(r, "rsqrt_grad");
612623
}
613624
};

paddle/phi/kernels/xpu/flash_attn_utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,3 @@ static void GenerateRNGState(
8484
}
8585
}
8686
} // namespace phi
87-
#

paddle/phi/kernels/xpu/p_recv_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ void PRecvKernel(const Context& dev_ctx,
5353
#else
5454
PADDLE_THROW(common::errors::PreconditionNotMet(
5555
"PaddlePaddle is not compiled with DWITH_XPU_BKCL, please recompile with "
56-
"DWITH_XPU_BKCL for using p_recv_kernel."));
56+
"DWITH_XPU_BKCL for using p_recv kernel."));
5757
#endif
5858
}
5959

@@ -80,7 +80,7 @@ void PRecvArrayKernel(const Context& dev_ctx,
8080
#else
8181
PADDLE_THROW(common::errors::PreconditionNotMet(
8282
"PaddlePaddle is not compiled with DWITH_XPU_BKCL, please recompile with "
83-
"DWITH_XPU_BKCL for using p_recv_kernel."));
83+
"DWITH_XPU_BKCL for using p_recv_array kernel."));
8484
#endif
8585
}
8686

paddle/phi/kernels/xpu/p_send_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void PSendKernel(const Context& dev_ctx,
4545
#else
4646
PADDLE_THROW(common::errors::PreconditionNotMet(
4747
"PaddlePaddle is not compiled with DWITH_XPU_BKCL, please recompile with "
48-
"DWITH_XPU_BKCL for using p_send_kernel."));
48+
"DWITH_XPU_BKCL for using p_send kernel."));
4949
#endif
5050
}
5151

@@ -68,7 +68,7 @@ void PSendArrayKernel(const Context& dev_ctx,
6868
#else
6969
PADDLE_THROW(common::errors::PreconditionNotMet(
7070
"PaddlePaddle is not compiled with DWITH_XPU_BKCL, please recompile with "
71-
"DWITH_XPU_BKCL for using p_send_kernel."));
71+
"DWITH_XPU_BKCL for using p_send_array kernel."));
7272
#endif
7373
}
7474

paddle/phi/kernels/xpu/prelu_grad_kernel.cc

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,48 +38,48 @@ void PReluGradKernel(const Context& dev_ctx,
3838

3939
auto x_dim = x.dims();
4040
auto x_rank = x_dim.size();
41-
4241
std::vector<int64_t> x_shape(x_rank);
4342
if (x_rank == 0) {
4443
x_shape = std::vector<int64_t>({1});
4544
} else {
46-
for (int i = 0; i < x_rank; i++) {
47-
x_shape[i] = x_dim[i];
48-
}
45+
x_shape = common::vectorize<int64_t>(x_dim);
4946
}
5047

51-
// mode = 0: channel_nchw, slope_shape = {c}, default. meanwhile, xshape = {n,
52-
// c, h, w}
53-
// mode = 1, channel_nhwc, slope_shape = {c}, meanwhile, xshape = {n, h, w, c}
54-
// mode = 2, elementwise, slope_shape = {c*h*w}
55-
// mode = 3, single slope, slope_shape = {1}
48+
// mode = 0: channel_nchw, xshape = {n, c, h, w}, alpha_shape = {c}
49+
// mode = 1, channel_nhwc, xshape = {n, h, w, c}, alpha_shape = {c}
50+
// mode = 2, elementwise, deprecated in Paddle 2.x
51+
// mode = 3, alpha_shape = {} or {1}
5652

5753
int xpu_mode = 0;
5854

5955
if (mode == "channel") {
6056
if (data_format == "NCHW") {
6157
xpu_mode = 0;
62-
} else {
63-
// NHWC
58+
if (x_rank == 2) { // special case for NC shape, use channel last mode
59+
xpu_mode == 1;
60+
}
61+
} else { // NHWC, channel last
6462
xpu_mode = 1;
6563
}
6664
} else if (mode == "element") {
6765
xpu_mode = 2;
68-
} else {
66+
} else if (mode == "all") {
6967
xpu_mode = 3;
68+
} else {
69+
PADDLE_THROW(common::errors::InvalidArgument(
70+
"Expected mode of prelu kernel is 'channel' or 'all', But got "
71+
"unsupported mode: %s.",
72+
mode));
7073
}
7174

72-
int r = xpu::prelu_grad(
73-
dev_ctx.x_context(),
74-
reinterpret_cast<const XPUType*>(x_ptr),
75-
reinterpret_cast<const XPUType*>(
76-
out_grad_ptr), /* const T* y, not used in xpu kernel */
77-
reinterpret_cast<const XPUType*>(alpha_ptr),
78-
reinterpret_cast<const XPUType*>(out_grad_ptr),
79-
reinterpret_cast<XPUType*>(x_grad_ptr),
80-
reinterpret_cast<XPUType*>(alpha_grad_ptr),
81-
x_shape,
82-
xpu_mode);
75+
int r = xpu::prelu_grad(dev_ctx.x_context(),
76+
reinterpret_cast<const XPUType*>(x_ptr),
77+
reinterpret_cast<const XPUType*>(alpha_ptr),
78+
reinterpret_cast<const XPUType*>(out_grad_ptr),
79+
reinterpret_cast<XPUType*>(x_grad_ptr),
80+
reinterpret_cast<XPUType*>(alpha_grad_ptr),
81+
x_shape,
82+
xpu_mode);
8383

8484
PADDLE_ENFORCE_XDNN_SUCCESS(r, "prelu_grad");
8585
}

paddle/phi/kernels/xpu/prelu_kernel.cc

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,33 +35,45 @@ void PReluKernel(const Context& dev_ctx,
3535
auto x_dim = x.dims();
3636
auto x_rank = x_dim.size();
3737
std::vector<int64_t> x_shape(x_rank);
38-
3938
if (x_rank == 0) {
4039
x_shape = std::vector<int64_t>({1});
4140
} else {
42-
for (int i = 0; i < x_rank; i++) {
43-
x_shape[i] = x_dim[i];
44-
}
41+
x_shape = common::vectorize<int64_t>(x_dim);
4542
}
4643

47-
auto alpha_dim = alpha.dims();
48-
auto alpha_rank = alpha_dim.size();
49-
std::vector<int64_t> alpha_shape(x_rank, 1); // same size with x_shape
44+
// mode = 0: channel_nchw, xshape = {n, c, h, w}, alpha_shape = {c}
45+
// mode = 1, channel_nhwc, xshape = {n, h, w, c}, alpha_shape = {c}
46+
// mode = 2, elementwise, deprecated in Paddle 2.x
47+
// mode = 3, alpha_shape = {} or {1}
5048

51-
if (x_rank == 0) {
52-
alpha_shape = std::vector<int64_t>({1});
53-
} else {
54-
for (int i = 0; i < alpha_rank; i++) {
55-
alpha_shape[i] = alpha_dim[i];
49+
int xpu_mode = 0;
50+
51+
if (mode == "channel") {
52+
if (data_format == "NCHW") {
53+
xpu_mode = 0;
54+
if (x_rank == 2) { // special case for NC shape, use channel last mode
55+
xpu_mode == 1;
56+
}
57+
} else { // NHWC, channel last
58+
xpu_mode = 1;
5659
}
60+
} else if (mode == "element") {
61+
xpu_mode = 2;
62+
} else if (mode == "all") {
63+
xpu_mode = 3;
64+
} else {
65+
PADDLE_THROW(common::errors::InvalidArgument(
66+
"Expected mode of prelu kernel is 'channel' or 'all', But got "
67+
"unsupported mode: %s.",
68+
mode));
5769
}
5870

5971
int r = xpu::prelu(dev_ctx.x_context(),
6072
reinterpret_cast<const XPUType*>(x_ptr),
6173
reinterpret_cast<const XPUType*>(alpha_ptr),
6274
reinterpret_cast<XPUType*>(y_ptr),
6375
x_shape,
64-
alpha_shape);
76+
xpu_mode);
6577

6678
PADDLE_ENFORCE_XDNN_SUCCESS(r, "prelu");
6779
}

paddle/phi/kernels/xpu/strided_copy_kernel.cc

Lines changed: 2 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,6 @@ void StridedCopyKernel(const Context& dev_ctx,
5454
"StridedCopyKernel's out tensor must complete "
5555
"mutable data before call kernel."));
5656

57-
// The following XPU operators have performance issues and are temporarily
58-
// disabled. A temporary workaround has been implemented: "First copy data to
59-
// CPU, perform computation using CPU operator logic, then copy results back
60-
// to XPU".
61-
/*
6257
// use XPUCopyTypeTrait to deal with double and int16_t copy instead of
6358
// XPUTypeTrait
6459
using XPUType = typename XPUCopyTypeTrait<T>::Type;
@@ -74,80 +69,17 @@ void StridedCopyKernel(const Context& dev_ctx,
7469
r = xpu::copy<XPUType>(dev_ctx.x_context(), input_data, output_data, 1);
7570
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
7671
} else {
72+
int64_t data_size = input.Holder()->size() - input.meta().offset;
7773
r = xpu::strided_copy<XPUType>(dev_ctx.x_context(),
7874
input_data,
7975
output_data,
76+
data_size,
8077
common::vectorize<int64_t>(input.dims()),
8178
common::vectorize<int64_t>(out->dims()),
8279
common::vectorize<int64_t>(input.strides()),
8380
common::vectorize<int64_t>(out->strides()));
8481
PADDLE_ENFORCE_XDNN_SUCCESS(r, "strided_copy");
8582
}
86-
*/
87-
88-
// wait before copy
89-
dev_ctx.Wait();
90-
91-
// CPU buffer for input
92-
char* input_on_cpu = new char[input.Holder()->size()];
93-
memory_utils::Copy(CPUPlace(),
94-
static_cast<void*>(input_on_cpu),
95-
dev_ctx.GetPlace(),
96-
static_cast<const void*>(input.Holder()->ptr()),
97-
input.Holder()->size());
98-
99-
// CPU buffer for out
100-
char* output_on_cpu = new char[out->Holder()->size()];
101-
memory_utils::Copy(CPUPlace(),
102-
static_cast<void*>(output_on_cpu),
103-
dev_ctx.GetPlace(),
104-
static_cast<const void*>(out->Holder()->ptr()),
105-
out->Holder()->size());
106-
107-
// wait after copy
108-
dev_ctx.Wait();
109-
110-
// follow paddle/phi/kernels/cpu/strided_copy_kernel.cc
111-
const T* input_data =
112-
reinterpret_cast<T*>(input_on_cpu + input.meta().offset);
113-
int input_rank = input.dims().size();
114-
const int64_t* input_dims = input.dims().Get();
115-
const int64_t* input_stride = input.strides().Get();
116-
117-
T* output_data = reinterpret_cast<T*>(output_on_cpu + offset);
118-
int output_rank = meta.dims.size();
119-
const int64_t* output_dims = meta.dims.Get();
120-
const int64_t* output_stride = meta.strides.Get();
121-
122-
auto numel = input.numel();
123-
124-
for (int64_t i = 0; i < numel; i++) {
125-
int64_t input_offset = 0;
126-
int64_t index_tmp = i;
127-
for (int dim = input_rank - 1; dim >= 0; --dim) {
128-
input_offset += (index_tmp % input_dims[dim]) * input_stride[dim];
129-
index_tmp = index_tmp / input_dims[dim];
130-
}
131-
int64_t output_offset = 0;
132-
index_tmp = i;
133-
for (int dim = output_rank - 1; dim >= 0; --dim) {
134-
output_offset += (index_tmp % output_dims[dim]) * output_stride[dim];
135-
index_tmp = index_tmp / output_dims[dim];
136-
}
137-
output_data[output_offset] = input_data[input_offset];
138-
}
139-
140-
// copy out tensor, from cpu to xpu
141-
memory_utils::Copy(dev_ctx.GetPlace(),
142-
static_cast<void*>(out->Holder()->ptr()),
143-
CPUPlace(),
144-
static_cast<const void*>(output_on_cpu),
145-
out->Holder()->size());
146-
// wait after copy
147-
dev_ctx.Wait();
148-
149-
delete[] input_on_cpu;
150-
delete[] output_on_cpu;
15183
}
15284

15385
} // namespace phi

0 commit comments

Comments
 (0)