|
15 | 15 | #pragma once
|
16 | 16 |
|
17 | 17 | #include "paddle/cinn/hlir/framework/pir/compilation_task.h"
|
| 18 | +#include "paddle/cinn/ir/utils/ir_copy.h" |
18 | 19 |
|
19 | 20 | #include "paddle/cinn/backends/codegen_device_util.h"
|
20 | 21 | #include "paddle/cinn/common/dim_expr_converter.h"
|
@@ -241,6 +242,74 @@ void CompilationTask::Lowering() {
|
241 | 242 |
|
242 | 243 | context_->broadcast_condition_ = ChangeBroadcastConditionToExpr();
|
243 | 244 | }
|
| 245 | + |
| 246 | + auto SimplifyPredicate = [](GroupCompilationContext* context) { |
| 247 | + for (auto& expr : context->predicates_) { |
| 248 | + optim::SimplifyLogical(&expr); |
| 249 | + } |
| 250 | + if (context->broadcast_condition_.defined()) |
| 251 | + optim::SimplifyLogical(&context->broadcast_condition_); |
| 252 | + for (auto& expr : context->CX86_predicates_) { |
| 253 | + optim::SimplifyLogical(&expr); |
| 254 | + } |
| 255 | + }; |
| 256 | + |
| 257 | + // remove unreachable predicates. |
| 258 | + auto RemoveUnreachPredicate = [](GroupCompilationContext* context) { |
| 259 | + // remove unreachable predicate. |
| 260 | + std::vector<ir::Expr> new_predicates; |
| 261 | + std::vector<int> new_priorities; |
| 262 | + std::vector<ir::LoweredFunc> new_lowered_funcs; |
| 263 | + bool has_true_predicate = false; |
| 264 | + for (size_t i = 0; i < context->predicates_.size(); ++i) { |
| 265 | + if (has_true_predicate) continue; |
| 266 | + if (common::IsZero(context->predicates_[i])) continue; |
| 267 | + if (common::IsOne(context->predicates_[i])) has_true_predicate = true; |
| 268 | + new_predicates.push_back(context->predicates_[i]); |
| 269 | + new_priorities.push_back(context->priorities_[i]); |
| 270 | + new_lowered_funcs.push_back(context->lowered_funcs_[i]); |
| 271 | + } |
| 272 | + // CINN does not support returning an empty module now. if all predicates |
| 273 | + // are false, we push the first predicate as result. |
| 274 | + if (new_predicates.empty() && !context->predicates_.empty()) { |
| 275 | + new_predicates.push_back(context->predicates_[0]); |
| 276 | + new_priorities.push_back(context->priorities_[0]); |
| 277 | + new_lowered_funcs.push_back(context->lowered_funcs_[0]); |
| 278 | + } |
| 279 | + context->predicates_ = std::move(new_predicates); |
| 280 | + context->priorities_ = std::move(new_priorities); |
| 281 | + context->lowered_funcs_ = std::move(new_lowered_funcs); |
| 282 | + |
| 283 | + // remove unreachable CX86 predicate. |
| 284 | + std::vector<ir::Expr> new_CX86_predicates; |
| 285 | + std::vector<ir::LoweredFunc> new_CX86_lowered_funcs; |
| 286 | + bool has_true_CX86_predicate = false; |
| 287 | + for (size_t i = 0; i < context->CX86_predicates_.size(); ++i) { |
| 288 | + if (has_true_CX86_predicate) continue; |
| 289 | + if (common::IsZero(context->CX86_predicates_[i])) continue; |
| 290 | + if (common::IsOne(context->CX86_predicates_[i])) |
| 291 | + has_true_CX86_predicate = true; |
| 292 | + new_CX86_predicates.push_back(context->CX86_predicates_[i]); |
| 293 | + new_CX86_lowered_funcs.push_back(context->CX86_lowered_funcs_[i]); |
| 294 | + } |
| 295 | + // CINN does not support returning an empty module now. if all predicates |
| 296 | + // are false, we push the first predicate as result. |
| 297 | + if (new_CX86_predicates.empty() && !context->CX86_predicates_.empty()) { |
| 298 | + new_CX86_predicates.push_back(context->CX86_predicates_[0]); |
| 299 | + new_CX86_lowered_funcs.push_back(context->CX86_lowered_funcs_[0]); |
| 300 | + } |
| 301 | + context->CX86_predicates_ = std::move(new_CX86_predicates); |
| 302 | + context->CX86_lowered_funcs_ = std::move(new_CX86_lowered_funcs); |
| 303 | + }; |
| 304 | + // Logical Simplifysimplify predicates, such as: |
| 305 | + // false && ... ==> false |
| 306 | + // true || ... ==> true |
| 307 | + // 1 <= 1 ==> true |
| 308 | + SimplifyPredicate(context_); |
| 309 | + // Remove unreachable predicates, unreachable predicates means that predicate |
| 310 | + // is false, or a true predicate already existed before. |
| 311 | + RemoveUnreachPredicate(context_); |
| 312 | + |
244 | 313 | VLOG(5) << "End to lowering: " << context_->PrintPredicate2Funcs();
|
245 | 314 | }
|
246 | 315 |
|
|
0 commit comments