@@ -315,21 +315,21 @@ bool IfOp::InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis) {
315
315
// infer false block
316
316
pir::InferSymExprForBlock (false_block (), shape_analysis);
317
317
318
+ auto GetSymExprForBlockResult =
319
+ [shape_analysis](const pir::Operation &op,
320
+ uint32_t idx) -> const std::vector<symbol::DimExpr> & {
321
+ const auto &shape_or_data =
322
+ shape_analysis->GetShapeOrDataForValue (op.operand_source (idx));
323
+ if (shape_or_data.data ().has_value ()) {
324
+ return shape_or_data.data ().value ();
325
+ } else {
326
+ return shape_or_data.shape ();
327
+ }
328
+ };
329
+
318
330
// TODO(lanxianghit): for llama, `if` op's result num always > 0, but
319
331
// result_num == 0 should be supported in future
320
332
if (num_results () > 0 ) {
321
- auto GetSymExprForBlockResult =
322
- [shape_analysis](const pir::Operation &op,
323
- uint32_t idx) -> const std::vector<symbol::DimExpr> & {
324
- const auto &shape_or_data =
325
- shape_analysis->GetShapeOrDataForValue (op.operand_source (idx));
326
- if (shape_or_data.data ().has_value ()) {
327
- return shape_or_data.data ().value ();
328
- } else {
329
- return shape_or_data.shape ();
330
- }
331
- };
332
-
333
333
for (uint32_t rst_idx = 0 ; rst_idx < num_results (); rst_idx++) {
334
334
const auto &true_dims =
335
335
GetSymExprForBlockResult (true_block ().back (), rst_idx);
0 commit comments