Skip to content

Commit 36590f6

Browse files
[CINN] Reconstruct shape_analysis (#63790)
* reconstruct shape_analysis * fix input value shape infer * fix merge bugs * fix concat and gather op InferSymbolicShape * fix merge bug * fix value_to_shape_or_data hash error and add some checks * fix set shape for null value * fix group op lazy infer * add IsStaticShape check * fix merge bug * support static dim check and set for VectorType * change auto to detail type
1 parent c0b2c7d commit 36590f6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1053
-857
lines changed

paddle/cinn/hlir/dialect/operator/ir/manual_op.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,14 @@ void GroupOp::Print(pir::IrPrinter& printer) {
116116
}
117117

118118
bool GroupOp::InferSymbolicShape(
119-
::pir::ShapeConstraintIRAnalysis* shape_analysis) {
120-
::pir::InferSymExprForBlock(*block(), shape_analysis);
119+
::pir::InferSymbolicShapeContext* infer_context) {
120+
::pir::InferSymExprForBlock(*block(), infer_context);
121121

122122
for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) {
123123
auto inner_yield_value = block()->back().operand_source(rst_idx);
124124
const auto& shape =
125-
shape_analysis->GetShapeOrDataForValue(inner_yield_value);
126-
shape_analysis->SetShapeOrDataForValue(result(rst_idx), shape);
125+
infer_context->GetShapeOrDataForValue(inner_yield_value);
126+
infer_context->SetShapeOrDataForValue(result(rst_idx), shape);
127127
}
128128

129129
if (VLOG_IS_ON(4)) {
@@ -204,16 +204,16 @@ void YieldStoreOp::Build(pir::Builder& builder,
204204
void YieldStoreOp::VerifySig() {}
205205

206206
bool YieldStoreOp::InferSymbolicShape(
207-
pir::ShapeConstraintIRAnalysis* shape_analysis) {
208-
shape_analysis->SetShapeOrDataForValue(
209-
result(0), shape_analysis->GetShapeOrDataForValue(operand_source(0)));
207+
pir::InferSymbolicShapeContext* infer_context) {
208+
infer_context->SetShapeOrDataForValue(
209+
result(0), infer_context->GetShapeOrDataForValue(operand_source(0)));
210210
return true;
211211
}
212212

213213
bool ConcatOp::InferSymbolicShape(
214-
pir::ShapeConstraintIRAnalysis* shape_analysis) {
214+
pir::InferSymbolicShapeContext* infer_context) {
215215
VLOG(4) << "Infer symbolic shape for cinn_op.concat";
216-
return ConcatOpInferSymbolicShape(this->operation(), shape_analysis);
216+
return ConcatOpInferSymbolicShape(this->operation(), infer_context);
217217
}
218218

219219
void ConcatOp::Build(pir::Builder& builder, // NOLINT
@@ -476,7 +476,7 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings(
476476
}
477477

478478
bool GenerateShapeOp::InferSymbolicShape(
479-
pir::ShapeConstraintIRAnalysis* shape_analysis) {
479+
pir::InferSymbolicShapeContext* infer_context) {
480480
const auto attr_dim_exprs = [&] {
481481
std::vector<symbol::DimExpr> dim_exprs{};
482482
pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs");
@@ -505,7 +505,7 @@ bool GenerateShapeOp::InferSymbolicShape(
505505
}();
506506
auto DimExprs4InputDim =
507507
[&](int input_idx) -> const symbol::ShapeOrDataDimExprs& {
508-
return shape_analysis->GetShapeOrDataForValue(
508+
return infer_context->GetShapeOrDataForValue(
509509
this->operand_source(input_idx));
510510
};
511511
auto DimExprs4SymbolName =
@@ -527,7 +527,7 @@ bool GenerateShapeOp::InferSymbolicShape(
527527
symbol::ShapeOrDataDimExprs shape_or_data_dim_exprs{
528528
symbol::TensorShapeOrDataDimExprs(shape, substituted_dim_exprs)};
529529

530-
shape_analysis->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);
530+
infer_context->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);
531531

532532
return true;
533533
}

paddle/cinn/hlir/dialect/operator/ir/manual_op.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class IR_API GroupOp
5353
pir::Block *block() const;
5454
std::vector<pir::Operation *> GetOperators() const;
5555

56-
bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
56+
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
5757

5858
void VerifySig();
5959
void Print(pir::IrPrinter &printer); // NOLINT
@@ -102,7 +102,7 @@ class IR_API YieldStoreOp
102102

103103
void VerifySig();
104104

105-
bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
105+
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
106106
};
107107

108108
class IR_API ConcatOp
@@ -123,7 +123,7 @@ class IR_API ConcatOp
123123

124124
void VerifySig() const {}
125125

126-
bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
126+
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
127127
};
128128

129129
class IR_API SplitOp : public pir::Op<SplitOp> {
@@ -177,7 +177,7 @@ class IR_API GenerateShapeOp
177177

178178
pir::Value out() { return result(0); }
179179

180-
bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
180+
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
181181

182182
static pir::Attribute ConvertSymbolBindingsToAttribute(
183183
pir::Builder &builder, const SymbolBindings &symbol_bindings); // NOLINT

paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,9 @@ class AddYieldStoreInFusionOpPattern
4545
auto orignal_base = op->operand_source(i);
4646
op->operand(i).set_source(store_op.result(0));
4747

48-
if (shape_analysis.HasShapeOrDataForValue(orignal_base)) {
49-
shape_analysis.SetShapeOrDataForValue(
50-
store_op.result(0),
51-
shape_analysis.GetShapeOrDataForValue(orignal_base));
52-
}
48+
shape_analysis.SetShapeOrDataForValue(
49+
store_op.result(0),
50+
shape_analysis.GetShapeOrDataForValue(orignal_base));
5351
}
5452

5553
return true;

paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_pass.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,10 @@ class BlockDimExprsAsserter {
144144
PADDLE_THROW(phi::errors::Unimplemented(
145145
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
146146
} else {
147-
bool infer_result = interface.InferSymbolicShape(shape_analysis.get());
147+
// TODO(Hongqing-work): delete this after the shape analysis reconstruct
148+
// is done.
149+
bool infer_result = interface.InferSymbolicShape(
150+
shape_analysis->GetInferSymbolicShapeContext());
148151
PADDLE_ENFORCE_EQ(infer_result,
149152
true,
150153
::common::errors::PreconditionNotMet(

paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,9 @@ ::pir::GroupOpsVec CloneOps(
182182
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
183183

184184
for (size_t i = 0; i < op->num_results(); ++i) {
185-
if (shape_analysis.HasShapeOrDataForValue(op->result(i))) {
186-
shape_analysis.SetShapeOrDataForValue(
187-
new_op->result(i),
188-
shape_analysis.GetShapeOrDataForValue(op->result(i)));
189-
}
185+
shape_analysis.SetShapeOrDataForValue(
186+
new_op->result(i),
187+
shape_analysis.GetShapeOrDataForValue(op->result(i)));
190188
}
191189

192190
vec_new_op_list.push_back(new_op);
@@ -357,11 +355,9 @@ class CinnGroupClusterPattern
357355
// update ir mapping
358356
for (size_t i = 0; i < output_values.size(); ++i) {
359357
ir_mapping.Add(output_values[i], new_group_op->result(i));
360-
if (shape_analysis.HasShapeOrDataForValue(output_values[i])) {
361-
shape_analysis.SetShapeOrDataForValue(
362-
new_group_op->result(i),
363-
shape_analysis.GetShapeOrDataForValue(output_values[i]));
364-
}
358+
shape_analysis.SetShapeOrDataForValue(
359+
new_group_op->result(i),
360+
shape_analysis.GetShapeOrDataForValue(output_values[i]));
365361
}
366362
for (size_t i = 0; i < output_values.size(); ++i) {
367363
auto find_it = all_output_values.find(output_values[i]);

paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,16 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op,
3333
std::vector<int> shape = phi::vectorize<int>(
3434
output.type().dyn_cast<pir::DenseTensorType>().dims());
3535

36-
if (shape_analysis->HasShapeOrDataForValue(op->result(0))) {
37-
const auto& shape_info =
38-
shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
39-
int temp_dim = -1;
40-
41-
for (size_t i = 0; i < shape_info.size(); ++i) {
42-
if (shape_info[i].isa<int64_t>()) {
43-
shape[i] = shape_info[i].Get<int64_t>();
44-
} else {
45-
shape[i] = temp_dim;
46-
temp_dim = 1;
47-
}
36+
const auto& shape_info =
37+
shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
38+
int temp_dim = -1;
39+
40+
for (size_t i = 0; i < shape_info.size(); ++i) {
41+
if (shape_info[i].isa<int64_t>()) {
42+
shape[i] = shape_info[i].Get<int64_t>();
43+
} else {
44+
shape[i] = temp_dim;
45+
temp_dim = 1;
4846
}
4947
}
5048
return shape;

paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.cc

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,11 @@ bool RemoveOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
5353
if (has_dynamic_shape) {
5454
auto& shape_analysis =
5555
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
56-
if (shape_analysis.HasShapeOrDataForValue(input) &&
57-
shape_analysis.HasShapeOrDataForValue(output)) {
58-
auto input_sym_shape =
59-
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
60-
auto output_sym_shape =
61-
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
62-
return input_sym_shape == output_sym_shape;
63-
}
64-
return false;
56+
auto input_sym_shape =
57+
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
58+
auto output_sym_shape =
59+
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
60+
return input_sym_shape == output_sym_shape;
6561
}
6662
return GetDims(input) == GetDims(output);
6763
};

paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,10 @@ void InferSymbolicShapeForSubgraph(
214214
auto infer_symbolic_shape_interface =
215215
op->dyn_cast<paddle::dialect::InferSymbolicShapeInterface>();
216216
if (infer_symbolic_shape_interface) {
217-
infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis);
217+
// TODO(Hongqing-work): delete this after the shape analysis reconstruct
218+
// is done.
219+
infer_symbolic_shape_interface.InferSymbolicShape(
220+
shape_analysis->GetInferSymbolicShapeContext());
218221
} else {
219222
PADDLE_THROW(phi::errors::Unimplemented(
220223
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
@@ -348,7 +351,6 @@ bool ReplaceShapeOpsToGenerateShape(
348351
auto ShapeOrDataDimExprs4Value =
349352
[&shape_analysis](
350353
pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
351-
CHECK(shape_analysis->HasShapeOrDataForValue(value));
352354
return shape_analysis->GetShapeOrDataForValue(value);
353355
};
354356
std::optional<pir::Value> opt_generated_shape =

paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,6 @@ class DynamicToStaticConverter {
104104
}
105105

106106
bool Convert() {
107-
if (!IsSymbolFullyInfered()) {
108-
return false;
109-
}
110107
bool updated = false;
111108
VisitEachValue(fusion_op_, [&](pir::Value value) {
112109
updated |= UpdateValueShape(value);
@@ -116,16 +113,6 @@ class DynamicToStaticConverter {
116113
}
117114

118115
private:
119-
bool IsSymbolFullyInfered() {
120-
bool is_infered = true;
121-
VisitEachValue(fusion_op_, [&](pir::Value value) {
122-
if (!shape_analysis_->HasShapeOrDataForValue(value)) {
123-
is_infered = false;
124-
}
125-
});
126-
return is_infered;
127-
}
128-
129116
DimExpr4SymbolName InitDimExpr4SymbolName() {
130117
const auto* map = GetGlobalDynamicToStaticDimMap();
131118
CHECK(map->has_value());
@@ -178,7 +165,6 @@ class DynamicToStaticConverter {
178165

179166
bool UpdateValueShape(pir::Value value) {
180167
bool update = false;
181-
CHECK(shape_analysis_->HasShapeOrDataForValue(value));
182168
const auto& origin_shape = GetOriginValueShape(value);
183169
const auto& target_shape = GetTargetValueShape(value);
184170
PADDLE_ENFORCE_EQ(

paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,6 @@ struct StaticDimToDynamicConverter {
150150
&pir::ShapeAnalysisManager::Instance().Get(
151151
this->fusion_op->GetParentProgram());
152152
ForEachValue([&](pir::Value value) {
153-
CHECK(shape_analysis->HasShapeOrDataForValue(value));
154153
const auto& origin_shape = GetOriginValueShape(value);
155154
const auto& target_shape = GetTargetValueShape(
156155
shape_analysis->GetShapeOrDataForValue(value).shape());
@@ -369,26 +368,8 @@ struct StaticDimToDynamicConverter {
369368
pir::Value value,
370369
int64_t constant,
371370
const std::string& symbol) {
372-
if (shape_analysis->HasShapeOrDataForValue(value)) {
373-
const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
374-
return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
375-
} else {
376-
auto& dims = value.type().dyn_cast<::pir::DenseTensorType>().dims();
377-
const auto& int_dims = ::common::vectorize<int>(dims);
378-
std::vector<symbol::DimExpr> old{};
379-
for (int dim : int_dims) {
380-
old.emplace_back(static_cast<std::int64_t>(dim));
381-
}
382-
const auto& opt_exprs =
383-
ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
384-
if (opt_exprs.has_value()) {
385-
return opt_exprs.value();
386-
} else {
387-
return symbol::ShapeOrDataDimExprs{
388-
symbol::TensorShapeOrDataDimExprs(old)};
389-
}
390-
}
391-
PADDLE_THROW(phi::errors::Fatal("Dead code"));
371+
const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
372+
return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
392373
}
393374

394375
template <typename ConverterT>

paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,14 @@ void SimplifyDimExpr(pir::Operation* module_op) {
101101

102102
VisitEachOp(module_op, [&](pir::Operation& op) {
103103
VisitEachValue(op, [&](pir::Value value) {
104-
if (!shape_analysis->HasShapeOrDataForValue(value)) {
105-
VLOG(4) << "SimplifyDimExpr: shape_analysis can't find ShapeOrData for "
106-
"value of the op:"
107-
<< op.name();
108-
} else {
109-
const symbol::ShapeOrDataDimExprs& shape_or_data =
110-
shape_analysis->GetShapeOrDataForValue(value);
111-
VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
112-
symbol::ShapeOrDataDimExprs simplified_shape_or_data =
113-
SimplifyShapeOrData(shape_or_data);
114-
VLOG(8) << op.name()
115-
<< " simplified_shape_or_data: " << simplified_shape_or_data;
116-
shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
117-
}
104+
const symbol::ShapeOrDataDimExprs& shape_or_data =
105+
shape_analysis->GetShapeOrDataForValue(value);
106+
VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
107+
symbol::ShapeOrDataDimExprs simplified_shape_or_data =
108+
SimplifyShapeOrData(shape_or_data);
109+
VLOG(8) << op.name()
110+
<< " simplified_shape_or_data: " << simplified_shape_or_data;
111+
shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
118112
});
119113
if (op.num_results() > 0) {
120114
pir::shape::SetShapeAttrForOp(

paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,9 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
6464
for (size_t i = 0; i < fusion_op.num_results(); ++i) {
6565
rewriter.ReplaceAllUsesWith(fusion_op.result(i),
6666
paddle_op.value()->result(i));
67-
if (shape_analysis.HasShapeOrDataForValue(fusion_op.result(i))) {
68-
shape_analysis.SetShapeOrDataForValue(
69-
paddle_op.value()->result(i),
70-
shape_analysis.GetShapeOrDataForValue(fusion_op.result(i)));
71-
} else {
72-
LOG(WARNING) << "No shape_data for "
73-
<< fusion_op.result(i).defining_op()->name() << "_result_"
74-
<< i << ", this may cause error in dynamic shape";
75-
}
67+
shape_analysis.SetShapeOrDataForValue(
68+
paddle_op.value()->result(i),
69+
shape_analysis.GetShapeOrDataForValue(fusion_op.result(i)));
7670
}
7771

7872
rewriter.EraseOp(fusion_op);

paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
4646
for (auto* op : std::vector{x_shape_op, y_shape_op, shape_broadcast_op}) {
4747
auto infer_symbolic_shape_interface =
4848
op->dyn_cast<paddle::dialect::InferSymbolicShapeInterface>();
49-
infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis);
49+
// TODO(Hongqing-work): delete this after the shape analysis reconstruct is
50+
// done.
51+
infer_symbolic_shape_interface.InferSymbolicShape(
52+
shape_analysis->GetInferSymbolicShapeContext());
5053
}
5154
return shape_broadcast_op->result(0);
5255
}

0 commit comments

Comments
 (0)