@@ -116,14 +116,14 @@ void GroupOp::Print(pir::IrPrinter& printer) {
116
116
}
117
117
118
118
bool GroupOp::InferSymbolicShape (
119
- ::pir::ShapeConstraintIRAnalysis* shape_analysis ) {
120
- ::pir::InferSymExprForBlock (*block (), shape_analysis );
119
+ ::pir::InferSymbolicShapeContext* infer_context ) {
120
+ ::pir::InferSymExprForBlock (*block (), infer_context );
121
121
122
122
for (uint32_t rst_idx = 0 ; rst_idx < num_results(); rst_idx++) {
123
123
auto inner_yield_value = block ()->back ().operand_source (rst_idx);
124
124
const auto & shape =
125
- shape_analysis ->GetShapeOrDataForValue (inner_yield_value);
126
- shape_analysis ->SetShapeOrDataForValue (result (rst_idx), shape);
125
+ infer_context ->GetShapeOrDataForValue (inner_yield_value);
126
+ infer_context ->SetShapeOrDataForValue (result (rst_idx), shape);
127
127
}
128
128
129
129
if (VLOG_IS_ON (4 )) {
@@ -204,16 +204,16 @@ void YieldStoreOp::Build(pir::Builder& builder,
204
204
void YieldStoreOp::VerifySig () {}
205
205
206
206
bool YieldStoreOp::InferSymbolicShape (
207
- pir::ShapeConstraintIRAnalysis* shape_analysis ) {
208
- shape_analysis ->SetShapeOrDataForValue (
209
- result (0 ), shape_analysis ->GetShapeOrDataForValue (operand_source (0 )));
207
+ pir::InferSymbolicShapeContext* infer_context ) {
208
+ infer_context ->SetShapeOrDataForValue (
209
+ result (0 ), infer_context ->GetShapeOrDataForValue (operand_source (0 )));
210
210
return true ;
211
211
}
212
212
213
213
bool ConcatOp::InferSymbolicShape (
214
- pir::ShapeConstraintIRAnalysis* shape_analysis ) {
214
+ pir::InferSymbolicShapeContext* infer_context ) {
215
215
VLOG (4 ) << " Infer symbolic shape for cinn_op.concat" ;
216
- return ConcatOpInferSymbolicShape (this ->operation (), shape_analysis );
216
+ return ConcatOpInferSymbolicShape (this ->operation (), infer_context );
217
217
}
218
218
219
219
void ConcatOp::Build (pir::Builder& builder, // NOLINT
@@ -476,7 +476,7 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings(
476
476
}
477
477
478
478
bool GenerateShapeOp::InferSymbolicShape (
479
- pir::ShapeConstraintIRAnalysis* shape_analysis ) {
479
+ pir::InferSymbolicShapeContext* infer_context ) {
480
480
const auto attr_dim_exprs = [&] {
481
481
std::vector<symbol::DimExpr> dim_exprs{};
482
482
pir::Attribute dim_expr_attr = this ->attributes ().at (" output_dim_exprs" );
@@ -505,7 +505,7 @@ bool GenerateShapeOp::InferSymbolicShape(
505
505
}();
506
506
auto DimExprs4InputDim =
507
507
[&](int input_idx) -> const symbol::ShapeOrDataDimExprs& {
508
- return shape_analysis ->GetShapeOrDataForValue (
508
+ return infer_context ->GetShapeOrDataForValue (
509
509
this ->operand_source (input_idx));
510
510
};
511
511
auto DimExprs4SymbolName =
@@ -527,7 +527,7 @@ bool GenerateShapeOp::InferSymbolicShape(
527
527
symbol::ShapeOrDataDimExprs shape_or_data_dim_exprs{
528
528
symbol::TensorShapeOrDataDimExprs (shape, substituted_dim_exprs)};
529
529
530
- shape_analysis ->SetShapeOrDataForValue (this ->out (), shape_or_data_dim_exprs);
530
+ infer_context ->SetShapeOrDataForValue (this ->out (), shape_or_data_dim_exprs);
531
531
532
532
return true ;
533
533
}
0 commit comments