Skip to content

Commit 91cfecf

Browse files
committed
format codes
1 parent a0517fa commit 91cfecf

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc

+12-12
Original file line numberDiff line numberDiff line change
@@ -315,21 +315,21 @@ bool IfOp::InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis) {
315315
// infer false block
316316
pir::InferSymExprForBlock(false_block(), shape_analysis);
317317

318+
auto GetSymExprForBlockResult =
319+
[shape_analysis](const pir::Operation &op,
320+
uint32_t idx) -> const std::vector<symbol::DimExpr> & {
321+
const auto &shape_or_data =
322+
shape_analysis->GetShapeOrDataForValue(op.operand_source(idx));
323+
if (shape_or_data.data().has_value()) {
324+
return shape_or_data.data().value();
325+
} else {
326+
return shape_or_data.shape();
327+
}
328+
};
329+
318330
// TODO(lanxianghit): for llama, `if` op's result num always > 0, but
319331
// result_num == 0 should be supported in future
320332
if (num_results() > 0) {
321-
auto GetSymExprForBlockResult =
322-
[shape_analysis](const pir::Operation &op,
323-
uint32_t idx) -> const std::vector<symbol::DimExpr> & {
324-
const auto &shape_or_data =
325-
shape_analysis->GetShapeOrDataForValue(op.operand_source(idx));
326-
if (shape_or_data.data().has_value()) {
327-
return shape_or_data.data().value();
328-
} else {
329-
return shape_or_data.shape();
330-
}
331-
};
332-
333333
for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) {
334334
const auto &true_dims =
335335
GetSymExprForBlockResult(true_block().back(), rst_idx);

0 commit comments

Comments
 (0)