Skip to content

Commit 929ae8b

Browse files
committed
add check
1 parent 3a2eab5 commit 929ae8b

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

+6
Original file line numberDiff line numberDiff line change
@@ -3588,6 +3588,12 @@ std::vector<pir::Type> ExpandOp::InferMeta(
35883588
.dyn_cast<paddle::dialect::IntArrayAttribute>()
35893589
.data()
35903590
.GetData();
3591+
PADDLE_ENFORCE_LE(shape_vec.size(),
3592+
1,
3593+
common::errors::InvalidArgument(
3594+
"The size of shape for Full op should be less than "
3595+
"or equal to 1, but receive %d.",
3596+
shape_vec.size()));
35913597
auto items = shape_vec.empty() ? 1 : shape_vec[0];
35923598
vec_shape = std::vector<int64_t>(items, shape_item);
35933599
} else if (shape.isa<pir::OpResult>() &&

paddle/fluid/pir/dialect/operator/utils/utils.cc

+6
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,12 @@ std::vector<int64_t> ParseValueShape(const pir::Value& shape,
432432
.dyn_cast<paddle::dialect::IntArrayAttribute>()
433433
.data()
434434
.GetData();
435+
PADDLE_ENFORCE_LE(shape_vec.size(),
436+
1,
437+
common::errors::InvalidArgument(
438+
"The size of shape for Full op should be less than "
439+
"or equal to 1, but receive %d.",
440+
shape_vec.size()));
435441
auto items = shape_vec.empty() ? 1 : shape_vec[0];
436442
vec_shape = std::vector<int64_t>(items, shape_item);
437443
} else if (shape.isa<pir::OpResult>() &&

0 commit comments

Comments
 (0)