Skip to content

Commit bcb9879

Browse files
fix add infermata (#66875)
* fix add infermata * refine * refine
1 parent 71ab427 commit bcb9879

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

paddle/phi/infermeta/binary.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,8 +1647,9 @@ void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
16471647

16481648
void ElementwiseInferMeta(const MetaTensor& x,
16491649
const MetaTensor& y,
1650-
MetaTensor* out) {
1651-
return ElementwiseRawInferMeta(x, y, -1, out);
1650+
MetaTensor* out,
1651+
MetaConfig config) {
1652+
return ElementwiseRawInferMeta(x, y, -1, out, config);
16521653
}
16531654

16541655
void BitwiseShiftInferMeta(const MetaTensor& x,
@@ -1691,6 +1692,7 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
16911692
std::vector<int> x_dims_array(max_dim);
16921693
std::vector<int> y_dims_array(max_dim);
16931694
std::vector<int> out_dims_array(max_dim);
1695+
16941696
#ifdef PADDLE_WITH_DNNL
16951697
bool should_rotate =
16961698
config.is_run_mkldnn_kernel &&
@@ -1699,7 +1701,7 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
16991701
(x_dims.size() >= 3 || y_dims.size() >= 3);
17001702
if (should_rotate) {
17011703
// Pick bigger shape and rotate this one
1702-
bool x_over_y = (x_dims.size() > y_dims.size());
1704+
bool x_over_y = (common::product(x_dims) > common::product(y_dims));
17031705
auto vdims = x_over_y ? common::vectorize<int>(x_dims)
17041706
: common::vectorize<int>(y_dims);
17051707
std::rotate(vdims.begin() + 1, vdims.begin() + 2, vdims.end());

paddle/phi/infermeta/binary.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,8 @@ void DropoutNdInferMeta(const MetaTensor& x,
324324

325325
TEST_API void ElementwiseInferMeta(const MetaTensor& x,
326326
const MetaTensor& y,
327-
MetaTensor* out);
327+
MetaTensor* out,
328+
MetaConfig config = MetaConfig());
328329

329330
void ElementwiseRawInferMeta(const MetaTensor& x_meta,
330331
const MetaTensor& y_meta,

0 commit comments

Comments
 (0)