From 6747af97ef8666fd87cb28b8b45fda4d935da18a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 21 Apr 2025 12:36:34 +0000 Subject: [PATCH] refine infer_symbol_shape of unique_consecutive op --- .../infer_symbolic_shape/unary_infer_sym.cc | 46 ++++++++++++++----- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 0d984c6678a9e..c2312117f7edc 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -4170,15 +4170,9 @@ bool UniqueOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); - PADDLE_ENFORCE_EQ( - x_shape_or_data.data().has_value(), - false, - common::errors::InvalidArgument( - "InferSymbolicShape of UniqueOp only support input with " - "value now.")); const auto &x_dims_sym = x_shape_or_data.shape(); const size_t rank = x_dims_sym.size(); - std::vector axes = + const std::vector axes = paddle::dialect::details::GetVectorAttr(op, "axis"); symbol::DimExpr unique_dim_sym = @@ -4246,8 +4240,40 @@ bool UniqueConsecutiveOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_dims_sym = x_shape_or_data.shape(); const size_t rank = x_dims_sym.size(); - std::vector axes = + const std::vector axes = paddle::dialect::details::GetVectorAttr(op, "axis"); + const bool return_inverse = GetBoolAttr(op, "return_inverse"); + const bool return_counts = GetBoolAttr(op, "return_counts"); + symbol::ShapeOrDataDimExprs empty{symbol::TensorShapeOrDataDimExprs{}}; + + // x has data + if (x_shape_or_data.data().has_value() && (rank == 1 || axes.empty())) { + const auto &x_data = x_shape_or_data.data().value(); + const bool is_all_const = [&] { + for (const auto &x_value : x_data) { + if (!x_value.isa()) return false; + } + return true; + }(); + if (is_all_const) { + auto out_data = x_data; + auto last = std::unique(out_data.begin(), out_data.end()); + out_data.erase(last, out_data.end()); + const std::vector out_size{ + static_cast(out_data.size())}; + + infer_context->SetShapeOrDataForValue( + op->result(0), symbol::TensorShapeOrDataDimExprs{out_size, out_data}); + infer_context->SetShapeOrDataForValue( + op->result(1), + return_inverse ? symbol::TensorShapeOrDataDimExprs{x_dims_sym} + : empty); + infer_context->SetShapeOrDataForValue( + op->result(2), + return_counts ? symbol::TensorShapeOrDataDimExprs{out_size} : empty); + return true; + } + } symbol::DimExpr unique_dim_sym = infer_context->GetNextSymName(); // unknown until runtime @@ -4286,10 +4312,6 @@ bool UniqueConsecutiveOpInferSymbolicShape( return inverse_dims; }(); - bool return_inverse = GetBoolAttr(op, "return_inverse"); - bool return_counts = GetBoolAttr(op, "return_counts"); - - symbol::ShapeOrDataDimExprs empty{symbol::TensorShapeOrDataDimExprs{}}; infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs{out_dims}); infer_context->SetShapeOrDataForValue(