Skip to content

Commit 7816a00

Browse files
committed
remove check temporarily
1 parent 929ae8b commit 7816a00

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

+6-7
Original file line numberDiff line numberDiff line change
@@ -3588,13 +3588,12 @@ std::vector<pir::Type> ExpandOp::InferMeta(
35883588
.dyn_cast<paddle::dialect::IntArrayAttribute>()
35893589
.data()
35903590
.GetData();
3591-
PADDLE_ENFORCE_LE(shape_vec.size(),
3592-
1,
3593-
common::errors::InvalidArgument(
3594-
"The size of shape for Full op should be less than "
3595-
"or equal to 1, but receive %d.",
3596-
shape_vec.size()));
3597-
auto items = shape_vec.empty() ? 1 : shape_vec[0];
3591+
// TODO(ooooo): If can make sure shape_value's size is less than or equal
3592+
// to 1, can add a check here rather than product.
3593+
int64_t items = 1;
3594+
for (const auto &item : shape_vec) {
3595+
items *= item;
3596+
}
35983597
vec_shape = std::vector<int64_t>(items, shape_item);
35993598
} else if (shape.isa<pir::OpResult>() &&
36003599
shape.defining_op()->isa<paddle::dialect::StackOp>()) {

paddle/fluid/pir/dialect/operator/utils/utils.cc

+7-7
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include <glog/logging.h>
16+
#include <cstdint>
1617
#include <sstream>
1718
#include <unordered_set>
1819

@@ -432,13 +433,12 @@ std::vector<int64_t> ParseValueShape(const pir::Value& shape,
432433
.dyn_cast<paddle::dialect::IntArrayAttribute>()
433434
.data()
434435
.GetData();
435-
PADDLE_ENFORCE_LE(shape_vec.size(),
436-
1,
437-
common::errors::InvalidArgument(
438-
"The size of shape for Full op should be less than "
439-
"or equal to 1, but receive %d.",
440-
shape_vec.size()));
441-
auto items = shape_vec.empty() ? 1 : shape_vec[0];
436+
// TODO(ooooo): If can make sure shape_value's size is less than or equal
437+
// to 1, can add a check here rather than product.
438+
int64_t items = 1;
439+
for (const auto& item : shape_vec) {
440+
items *= item;
441+
}
442442
vec_shape = std::vector<int64_t>(items, shape_item);
443443
} else if (shape.isa<pir::OpResult>() &&
444444
shape.defining_op()->isa<paddle::dialect::StackOp>()) {

0 commit comments

Comments
 (0)