Skip to content

Commit 9e9b38f

Browse files
add arg check for cummax (#59979)
* add arg check for cummax
1 parent e2a722f commit 9e9b38f

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

paddle/phi/infermeta/unary.cc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -590,21 +590,6 @@ void CumWithIndicesInferMeta(const MetaTensor& x,
590590
phi::errors::InvalidArgument(
591591
"dtype of indices must be DataType::INT32 or DataType::INT64"));
592592

593-
if (dtype == DataType::INT32) {
594-
int _axis = 0;
595-
if (axis < 0) {
596-
_axis = axis + x_dims.size();
597-
} else {
598-
_axis = axis;
599-
}
600-
PADDLE_ENFORCE_LT(
601-
common::vectorize(x_dims)[_axis],
602-
INT32_MAX,
603-
phi::errors::OutOfRange(
604-
"cummax with axis %ld may be overflow, set dtype int64 to continue",
605-
axis));
606-
}
607-
608593
if (x_dims.size() > 0) {
609594
PADDLE_ENFORCE_GE(
610595
axis,
@@ -633,6 +618,21 @@ void CumWithIndicesInferMeta(const MetaTensor& x,
633618
axis));
634619
}
635620

621+
if (dtype == DataType::INT32) {
622+
int _axis = 0;
623+
if (axis < 0) {
624+
_axis = axis + x_dims.size();
625+
} else {
626+
_axis = axis;
627+
}
628+
PADDLE_ENFORCE_LT(
629+
common::vectorize(x_dims)[_axis],
630+
INT32_MAX,
631+
phi::errors::OutOfRange(
632+
"cummax with axis %ld may be overflow, set dtype int64 to continue",
633+
axis));
634+
}
635+
636636
out->set_dims(x_dims);
637637
out->set_dtype(x.dtype());
638638
out->share_lod(x);

0 commit comments

Comments
 (0)