diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 9ea24832a175cf..2ba3271d2c7df6 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -74,14 +74,13 @@ struct BroadcastTypeClassifier { void InitBroadcastConfigs(const std::vector &ins, std::vector *outs, int axis) { +#ifdef PADDLE_WITH_XPU_KP const auto dims_simplifier = BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); if (VLOG_IS_ON(6)) { DimsSimplifiedLogger::Log( ins, outs, dims_simplifier, "BroadcastKernel"); } - -#ifdef PADDLE_WITH_XPU_KP configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims, dims_simplifier.in_dims[0], dims_simplifier.in_dims[1], @@ -92,6 +91,12 @@ struct BroadcastTypeClassifier { dims_simplifier.rank); #else if (!all_elementwise) { + const auto dims_simplifier = + BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); + if (VLOG_IS_ON(6)) { + DimsSimplifiedLogger::Log( + ins, outs, dims_simplifier, "BroadcastKernel"); + } for (int i = 0; i < Arity; ++i) { // if data shape is[m, n], then you should set data_dim = {n, m} // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}