@@ -74,14 +74,13 @@ struct BroadcastTypeClassifier {
74
74
void InitBroadcastConfigs (const std::vector<const DenseTensor *> &ins,
75
75
std::vector<DenseTensor *> *outs,
76
76
int axis) {
77
+ #ifdef PADDLE_WITH_XPU_KP
77
78
const auto dims_simplifier =
78
79
BroadcastDimsSimplifier (ins, (*outs)[0 ]->dims (), axis);
79
80
if (VLOG_IS_ON (6 )) {
80
81
DimsSimplifiedLogger<int64_t >::Log (
81
82
ins, outs, dims_simplifier, " BroadcastKernel" );
82
83
}
83
-
84
- #ifdef PADDLE_WITH_XPU_KP
85
84
configs[0 ] = kps::details::BroadcastConfig (dims_simplifier.out_dims ,
86
85
dims_simplifier.in_dims [0 ],
87
86
dims_simplifier.in_dims [1 ],
@@ -92,6 +91,12 @@ struct BroadcastTypeClassifier {
92
91
dims_simplifier.rank );
93
92
#else
94
93
if (!all_elementwise) {
94
+ const auto dims_simplifier =
95
+ BroadcastDimsSimplifier (ins, (*outs)[0 ]->dims (), axis);
96
+ if (VLOG_IS_ON (6 )) {
97
+ DimsSimplifiedLogger<int64_t >::Log (
98
+ ins, outs, dims_simplifier, " BroadcastKernel" );
99
+ }
95
100
for (int i = 0 ; i < Arity; ++i) {
96
101
// if data shape is[m, n], then you should set data_dim = {n, m}
97
102
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
0 commit comments