Skip to content

Commit ae4d1ec

Browse files
Modified reduce for xpu2 (#42439)
1 parent: 8b546f1 · commit: ae4d1ec

File tree

2 files changed: +8 -6 lines changed

2 files changed: +8 -6 lines changed

paddle/phi/kernels/funcs/reduce_function.h

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -473,7 +473,11 @@ struct ReduceConfig {
473 473
bool not_higher = x_dim[0] >= max_grid_z;
474 474
#endif
475 475
if (reduce_last_dim && (reduce_rank == 1)) {
476+
#ifdef PADDLE_WITH_XPU_KP
477+
reduce_type = static_cast<int>(ReduceType::kReduceAny);
478+
#else
476 479
reduce_type = static_cast<int>(ReduceType::kReduceLastDim);
480+
#endif
477 481
} else if (reduce_rank == 1) {
478 482
reduce_type = static_cast<int>(ReduceType::kReduceHigherDim);
479 483
if (rank == 3 && not_higher) {
@@ -588,7 +592,7 @@ struct ReduceConfig {
588 592
void SetBlockDim() {
589 593
// init
590 594
should_reduce_again = false;
591-
dim3 block_dim;
595+
dim3 block_dim(1, 1, 1);
592 596
dim3 grid_dim(left_num, 1, 1);
593 597
blocking_size = reduce_num;
594 598

paddle/phi/kernels/primitive/compute_primitives_xpu2.h

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -329,14 +329,12 @@ __device__ __forceinline__ void Reduce(T* out,
329 329
ReduceFunctor reducer,
330 330
bool reduce_last_dim) {
331 331
if (Mode == details::kGlobalMode) {
332+
if (reduce_last_dim) {
332 333
#pragma unroll
333-
for (int i = 0; i < NY; ++i) {
334-
#pragma unroll
335-
for (int j = 0; j < NX; ++j) {
336-
out[i] = reducer(out[i], in[i * NX + j]);
334+
for (int i = 0; i < NY * NX; i++) { // reduce along blockDim.x
335+
details::BlockXReduce<T, ReduceFunctor, 1>(&out[i], reducer);
337 336
}
338 337
}
339-
details::BlockXReduce<T, ReduceFunctor, NY>(out, reducer);
340 338
} else { // else kLocalMode
341 339
#pragma unroll
342 340
for (int i = 0; i < NY; ++i) {

0 commit comments

Comments
 (0)