Skip to content

Commit c294b45

Browse files
Aurelius84SigureMo
authored andcommitted
[CINN]Refine StaticShapeGroupScheduler code while learning logic (#59540)
* [CINN]Refine StaticShapeGroupScheduler code while learning logic * fix comment
1 parent 28587f1 commit c294b45

File tree

4 files changed

+32
-32
lines changed

4 files changed

+32
-32
lines changed

paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ bool IsProhibitScheduleExternCallBlock(ir::Expr block) {
6161
sch_block->body, [&](const Expr* x) { return x->As<ir::Call>(); });
6262
for (ir::Expr call : find_call) {
6363
ir::Call* call_node = call.As<ir::Call>();
64-
if (call.As<ir::Call>() && kProhibitScheduleExternalFuncNames.count(
65-
call.As<ir::Call>()->name) != 0) {
64+
if (kProhibitScheduleExternalFuncNames.count(call_node->name) != 0) {
6665
return true;
6766
}
6867
}
@@ -72,16 +71,14 @@ bool IsProhibitScheduleExternCallBlock(ir::Expr block) {
7271
// Find loops with same extents of 2 ScheduleBlock
7372
std::vector<std::tuple<ir::Expr, ir::Expr>> FindSameOuterLoops(
7473
ir::ScheduleBlockNode* source_node, ir::ScheduleBlockNode* target_node) {
75-
std::vector<ir::Expr> src_ctrl_stmts = source_node->ControlStmts();
76-
std::vector<ir::Expr> tgt_ctrl_stmts = target_node->ControlStmts();
74+
std::vector<ir::Expr> src_loops = source_node->GetLoops();
75+
std::vector<ir::Expr> tgt_loops = target_node->GetLoops();
7776
std::vector<std::tuple<ir::Expr, ir::Expr>> same_loops;
78-
int min_stmt_size = std::min(src_ctrl_stmts.size(), tgt_ctrl_stmts.size());
77+
int min_stmt_size = std::min(src_loops.size(), tgt_loops.size());
7978
for (int i = 0; i < min_stmt_size; ++i) {
80-
if (src_ctrl_stmts[i].As<ir::For>() && tgt_ctrl_stmts[i].As<ir::For>() &&
81-
ir::GetLoopExtent(src_ctrl_stmts[i]) ==
82-
GetLoopExtent(tgt_ctrl_stmts[i])) {
83-
same_loops.push_back(
84-
std::make_tuple(src_ctrl_stmts[i], tgt_ctrl_stmts[i]));
79+
if (src_loops[i].As<ir::For>() && tgt_loops[i].As<ir::For>() &&
80+
GetLoopExtent(src_loops[i]) == GetLoopExtent(tgt_loops[i])) {
81+
same_loops.push_back(std::make_tuple(src_loops[i], tgt_loops[i]));
8582
} else {
8683
break;
8784
}
@@ -93,8 +90,10 @@ std::vector<std::tuple<ir::Expr, ir::Expr>> FindSameOuterLoops(
9390
std::unordered_set<std::string> GetReduceLoopVarNames(ir::Expr block) {
9491
ir::ScheduleBlockRealize* schedule_block_realize =
9592
block.As<ir::ScheduleBlockRealize>();
93+
CHECK_NOTNULL(schedule_block_realize);
9694
ir::ScheduleBlock* schedule_block =
9795
schedule_block_realize->schedule_block.As<ir::ScheduleBlock>();
96+
CHECK_NOTNULL(schedule_block);
9897
std::vector<ir::Expr> iter_values = schedule_block_realize->iter_values;
9998
std::vector<ir::Var> iter_vars = schedule_block->iter_vars;
10099
std::unordered_set<std::string> reduce_loop_var_names;
@@ -115,9 +114,11 @@ std::unordered_set<std::string> GetReduceLoopVarNames(ir::Expr block) {
115114
std::unordered_set<std::string> GetReduceVarNames(ir::Expr block) {
116115
ir::ScheduleBlockRealize* schedule_block_realize =
117116
block.As<ir::ScheduleBlockRealize>();
117+
CHECK_NOTNULL(schedule_block_realize);
118118
ir::ScheduleBlock* schedule_block =
119119
schedule_block_realize->schedule_block.As<ir::ScheduleBlock>();
120-
std::vector<ir::Var> iter_vars = schedule_block->iter_vars;
120+
CHECK_NOTNULL(schedule_block);
121+
std::vector<ir::Var>& iter_vars = schedule_block->iter_vars;
121122
std::unordered_set<std::string> reduce_var_names;
122123
for (int i = 0; i < iter_vars.size(); ++i) {
123124
if (iter_vars[i]->is_reduce_axis) {
@@ -162,21 +163,21 @@ NodePriority StaticShapeGroupScheduler::CalculateNodePriority(
162163
GetReduceLoopVarNames(node->Block());
163164

164165
int64_t reduce_score = 1;
165-
double score = 1;
166-
for (Expr expr : node->ControlStmts()) {
166+
int64_t score = 1;
167+
for (Expr expr : node->GetLoops()) {
167168
ir::For* for_node = expr.As<ir::For>();
168-
if (for_node != nullptr) {
169-
score *= ir::GetLoopExtent(expr);
170-
}
169+
CHECK_NOTNULL(for_node);
170+
int loop_extent = ir::GetLoopExtent(expr);
171+
score *= loop_extent;
171172
if (reduce_loop_var_names.count(for_node->loop_var->name) != 0) {
172-
reduce_score *= ir::GetLoopExtent(expr);
173+
reduce_score *= loop_extent;
173174
}
174175
if (for_node->is_binded()) {
175176
has_loop_binded = true;
176177
}
177178
}
178179
if (reduce_score > 1) {
179-
score *= (reduce_score * std::log2(reduce_score));
180+
score = std::numeric_limits<int64_t>::max();
180181
}
181182

182183
VLOG(6) << "The priority score of node " << node->id() << " is " << score;
@@ -239,13 +240,12 @@ void StaticShapeGroupScheduler::DoLoopAlignment() {
239240
[&](const ir::Expr* x) {
240241
bool find_reduce_var = false;
241242
if (x->As<ir::Load>()) {
242-
int i = 0;
243243
for (ir::Expr index : x->As<ir::Load>()->indices) {
244244
if (index.as_var() &&
245245
reduce_var_names.count(index.as_var_ref()->name) > 0) {
246246
find_reduce_var = true;
247+
break;
247248
}
248-
++i;
249249
}
250250
}
251251
return find_reduce_var;
@@ -325,7 +325,7 @@ void StaticShapeGroupScheduler::DoLoopAlignment() {
325325
return false;
326326
}
327327

328-
for (ir::Expr expr : node->ControlStmts()) {
328+
for (ir::Expr expr : node->GetLoops()) {
329329
if (expr.As<ir::For>() != nullptr &&
330330
(expr.As<ir::For>()->for_type() == ir::ForType::GPUBlock ||
331331
expr.As<ir::For>()->for_type() == ir::ForType::GPUThread)) {
@@ -341,7 +341,7 @@ void StaticShapeGroupScheduler::DoLoopAlignment() {
341341
<< " with block: " << global_master->id();
342342

343343
// 1. Fuse source loops
344-
ir::Expr source_loop = ir_sch_->Fuse(node->ControlStmts());
344+
ir::Expr source_loop = ir_sch_->Fuse(node->GetLoops());
345345
int total_source_extent = ir::GetLoopExtent(source_loop);
346346

347347
// 2. Split source loop to align with the target loops
@@ -475,11 +475,11 @@ void StaticShapeGroupScheduler::DoVerticalLoopFusion() {
475475
ir::Expr target_loop;
476476
bool find_target_loop = false;
477477
// Collect infomation of original loops
478-
std::vector<ir::Expr> original_ctrl_stmts = node->ControlStmts();
478+
std::vector<ir::Expr> original_loops = node->GetLoops();
479479
int64_t original_total_loop_extent = 1;
480480
std::vector<std::pair<std::string, int>> original_loop_infos;
481481
std::unordered_set<ir::IrNode*> original_loop_node_ptrs;
482-
for (ir::Expr stmt : original_ctrl_stmts) {
482+
for (ir::Expr stmt : original_loops) {
483483
if (stmt.As<ir::For>()) {
484484
int extent = ir::GetLoopExtent(stmt);
485485
original_total_loop_extent *= extent;
@@ -550,15 +550,15 @@ void StaticShapeGroupScheduler::DoVerticalLoopFusion() {
550550
if (find_target_loop) {
551551
ir_sch_->SimpleComputeAt(node->Block(), target_loop);
552552
VLOG(6) << "after compute at: " << ir_sch_->GetModule().GetExprs()[0];
553-
std::vector<ir::Expr> new_stmts = node->ControlStmts();
553+
std::vector<ir::Expr> new_loops = node->GetLoops();
554554
for (int idx = 0; idx < original_loop_infos.size(); ++idx) {
555555
if (original_loop_infos[idx].first.empty()) {
556556
continue;
557557
}
558-
if (idx < new_stmts.size()) {
559-
CHECK(new_stmts[idx].As<ir::For>());
560-
if (new_stmts[idx].As<ir::For>()->is_serial()) {
561-
ir_sch_->Bind(new_stmts[idx], original_loop_infos[idx].first);
558+
if (idx < new_loops.size()) {
559+
CHECK(new_loops[idx].As<ir::For>());
560+
if (new_loops[idx].As<ir::For>()->is_serial()) {
561+
ir_sch_->Bind(new_loops[idx], original_loop_infos[idx].first);
562562
}
563563
} else {
564564
ir::Expr unit_loop = ir_sch_->AddUnitLoop(node->Block());

paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace ir {
2323
// and secondly considering the amount of calculated data.
2424
struct NodePriority {
2525
bool has_loop_binded;
26-
double score;
26+
int64_t score;
2727

2828
bool operator<(const NodePriority& other) const {
2929
if (has_loop_binded ^ other.has_loop_binded) {

paddle/cinn/ir/schedule_block_graph.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ ScheduleBlockNode::ScheduleBlockNode(Expr block, const IRSchedule& ir_sch)
3232

3333
Expr ScheduleBlockNode::Block() const { return ir_sch_.GetBlock(id_); }
3434

35-
std::vector<Expr> ScheduleBlockNode::ControlStmts() const {
35+
std::vector<Expr> ScheduleBlockNode::GetLoops() const {
3636
return ir_sch_.GetLoops(id_);
3737
}
3838

paddle/cinn/ir/schedule_block_graph.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class ScheduleBlockNode : public cinn::common::GraphNode {
4040

4141
// Get all control stmts containing the schedule_block, now only the For node
4242
// is being considered.
43-
std::vector<Expr> ControlStmts() const;
43+
std::vector<Expr> GetLoops() const;
4444

4545
// Get all the upstream nodes that this node depends on.
4646
std::unordered_set<std::string> UpstreamNodes() const {

0 commit comments

Comments
 (0)