Skip to content

Commit 6e0955a

Browse files
authored
【CINN】Simplify predicate and remove unreachable branch (#72316)
* simplify predicate after lowering * remove useless comment * remove useless comment * fix undefine bug * remove unreachable branch
1 parent d023362 commit 6e0955a

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

paddle/cinn/hlir/framework/pir/compilation_task.cc

+69
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "paddle/cinn/hlir/framework/pir/compilation_task.h"
18+
#include "paddle/cinn/ir/utils/ir_copy.h"
1819

1920
#include "paddle/cinn/backends/codegen_device_util.h"
2021
#include "paddle/cinn/common/dim_expr_converter.h"
@@ -241,6 +242,74 @@ void CompilationTask::Lowering() {
241242

242243
context_->broadcast_condition_ = ChangeBroadcastConditionToExpr();
243244
}
245+
246+
auto SimplifyPredicate = [](GroupCompilationContext* context) {
247+
for (auto& expr : context->predicates_) {
248+
optim::SimplifyLogical(&expr);
249+
}
250+
if (context->broadcast_condition_.defined())
251+
optim::SimplifyLogical(&context->broadcast_condition_);
252+
for (auto& expr : context->CX86_predicates_) {
253+
optim::SimplifyLogical(&expr);
254+
}
255+
};
256+
257+
// remove unreachable predicates.
258+
auto RemoveUnreachPredicate = [](GroupCompilationContext* context) {
259+
// remove unreachable predicate.
260+
std::vector<ir::Expr> new_predicates;
261+
std::vector<int> new_priorities;
262+
std::vector<ir::LoweredFunc> new_lowered_funcs;
263+
bool has_true_predicate = false;
264+
for (size_t i = 0; i < context->predicates_.size(); ++i) {
265+
if (has_true_predicate) continue;
266+
if (common::IsZero(context->predicates_[i])) continue;
267+
if (common::IsOne(context->predicates_[i])) has_true_predicate = true;
268+
new_predicates.push_back(context->predicates_[i]);
269+
new_priorities.push_back(context->priorities_[i]);
270+
new_lowered_funcs.push_back(context->lowered_funcs_[i]);
271+
}
272+
// CINN does not support returning an empty module now. if all predicates
273+
// are false, we push the first predicate as result.
274+
if (new_predicates.empty() && !context->predicates_.empty()) {
275+
new_predicates.push_back(context->predicates_[0]);
276+
new_priorities.push_back(context->priorities_[0]);
277+
new_lowered_funcs.push_back(context->lowered_funcs_[0]);
278+
}
279+
context->predicates_ = std::move(new_predicates);
280+
context->priorities_ = std::move(new_priorities);
281+
context->lowered_funcs_ = std::move(new_lowered_funcs);
282+
283+
// remove unreachable CX86 predicate.
284+
std::vector<ir::Expr> new_CX86_predicates;
285+
std::vector<ir::LoweredFunc> new_CX86_lowered_funcs;
286+
bool has_true_CX86_predicate = false;
287+
for (size_t i = 0; i < context->CX86_predicates_.size(); ++i) {
288+
if (has_true_CX86_predicate) continue;
289+
if (common::IsZero(context->CX86_predicates_[i])) continue;
290+
if (common::IsOne(context->CX86_predicates_[i]))
291+
has_true_CX86_predicate = true;
292+
new_CX86_predicates.push_back(context->CX86_predicates_[i]);
293+
new_CX86_lowered_funcs.push_back(context->CX86_lowered_funcs_[i]);
294+
}
295+
// CINN does not support returning an empty module now. if all predicates
296+
// are false, we push the first predicate as result.
297+
if (new_CX86_predicates.empty() && !context->CX86_predicates_.empty()) {
298+
new_CX86_predicates.push_back(context->CX86_predicates_[0]);
299+
new_CX86_lowered_funcs.push_back(context->CX86_lowered_funcs_[0]);
300+
}
301+
context->CX86_predicates_ = std::move(new_CX86_predicates);
302+
context->CX86_lowered_funcs_ = std::move(new_CX86_lowered_funcs);
303+
};
304+
// Logical Simplifysimplify predicates, such as:
305+
// false && ... ==> false
306+
// true || ... ==> true
307+
// 1 <= 1 ==> true
308+
SimplifyPredicate(context_);
309+
// Remove unreachable predicates, unreachable predicates means that predicate
310+
// is false, or a true predicate already existed before.
311+
RemoveUnreachPredicate(context_);
312+
244313
VLOG(5) << "End to lowering: " << context_->PrintPredicate2Funcs();
245314
}
246315

paddle/cinn/optim/ir_simplify.cc

+12-2
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,13 @@ struct SimplifyLogicalMutator : public ir::IRMutator<> {
256256
*expr = Expr(false);
257257
return;
258258
}
259-
if (common::IsOne(node->a()) && common::IsOne(node->b()))
259+
if (common::IsOne(node->a()) && common::IsOne(node->b())) {
260260
*expr = Expr(true);
261+
} else if (common::IsOne(node->a())) {
262+
*expr = node->b();
263+
} else if (common::IsOne(node->b())) {
264+
*expr = node->a();
265+
}
261266
VLOG(7) << "End Visit And op: " << *expr;
262267
}
263268

@@ -276,8 +281,13 @@ struct SimplifyLogicalMutator : public ir::IRMutator<> {
276281
VLOG(7) << "End visit Or op: " << *expr;
277282
return;
278283
}
279-
if (common::IsZero(node->a()) && common::IsZero(node->b()))
284+
if (common::IsZero(node->a()) && common::IsZero(node->b())) {
280285
*expr = Expr(false);
286+
} else if (common::IsZero(node->a())) {
287+
*expr = node->b();
288+
} else if (common::IsZero(node->b())) {
289+
*expr = node->a();
290+
}
281291
VLOG(7) << "End visit Or op: " << *expr;
282292
}
283293

0 commit comments

Comments
 (0)