@@ -1647,8 +1647,9 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
1647
1647
1648
1648
void ElementwiseInferMeta (const MetaTensor& x,
1649
1649
const MetaTensor& y,
1650
- MetaTensor* out) {
1651
- return ElementwiseRawInferMeta (x, y, -1 , out);
1650
+ MetaTensor* out,
1651
+ MetaConfig config) {
1652
+ return ElementwiseRawInferMeta (x, y, -1 , out, config);
1652
1653
}
1653
1654
1654
1655
void BitwiseShiftInferMeta (const MetaTensor& x,
@@ -1691,6 +1692,7 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
1691
1692
std::vector<int > x_dims_array (max_dim);
1692
1693
std::vector<int > y_dims_array (max_dim);
1693
1694
std::vector<int > out_dims_array (max_dim);
1695
+
1694
1696
#ifdef PADDLE_WITH_DNNL
1695
1697
bool should_rotate =
1696
1698
config.is_run_mkldnn_kernel &&
@@ -1699,7 +1701,7 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
1699
1701
(x_dims.size () >= 3 || y_dims.size () >= 3 );
1700
1702
if (should_rotate) {
1701
1703
// Pick bigger shape and rotate this one
1702
- bool x_over_y = (x_dims. size ( ) > y_dims. size ( ));
1704
+ bool x_over_y = (common::product (x_dims ) > common::product (y_dims ));
1703
1705
auto vdims = x_over_y ? common::vectorize<int >(x_dims)
1704
1706
: common::vectorize<int >(y_dims);
1705
1707
std::rotate (vdims.begin () + 1 , vdims.begin () + 2 , vdims.end ());
0 commit comments