Skip to content

Commit 4ec9033

Browse files
authored
Optimize the overhead for all_elementwise case in BroadcastKernel. (PaddlePaddle#57247)
1 parent 744fca0 commit 4ec9033

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,14 +74,13 @@ struct BroadcastTypeClassifier {
7474
void InitBroadcastConfigs(const std::vector<const DenseTensor *> &ins,
7575
std::vector<DenseTensor *> *outs,
7676
int axis) {
77+
#ifdef PADDLE_WITH_XPU_KP
7778
const auto dims_simplifier =
7879
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
7980
if (VLOG_IS_ON(6)) {
8081
DimsSimplifiedLogger<int64_t>::Log(
8182
ins, outs, dims_simplifier, "BroadcastKernel");
8283
}
83-
84-
#ifdef PADDLE_WITH_XPU_KP
8584
configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
8685
dims_simplifier.in_dims[0],
8786
dims_simplifier.in_dims[1],
@@ -92,6 +91,12 @@ struct BroadcastTypeClassifier {
9291
dims_simplifier.rank);
9392
#else
9493
if (!all_elementwise) {
94+
const auto dims_simplifier =
95+
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
96+
if (VLOG_IS_ON(6)) {
97+
DimsSimplifiedLogger<int64_t>::Log(
98+
ins, outs, dims_simplifier, "BroadcastKernel");
99+
}
95100
for (int i = 0; i < Arity; ++i) {
96101
// If the data shape is [m, n], then you should set data_dim = {n, m}.
97102
// e.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}.

0 commit comments

Comments
 (0)