@@ -61,8 +61,7 @@ bool IsProhibitScheduleExternCallBlock(ir::Expr block) {
61
61
sch_block->body , [&](const Expr* x) { return x->As <ir::Call>(); });
62
62
for (ir::Expr call : find_call) {
63
63
ir::Call* call_node = call.As <ir::Call>();
64
- if (call.As <ir::Call>() && kProhibitScheduleExternalFuncNames .count (
65
- call.As <ir::Call>()->name ) != 0 ) {
64
+ if (kProhibitScheduleExternalFuncNames .count (call_node->name ) != 0 ) {
66
65
return true ;
67
66
}
68
67
}
@@ -72,16 +71,14 @@ bool IsProhibitScheduleExternCallBlock(ir::Expr block) {
72
71
// Find loops with same extents of 2 ScheduleBlock
73
72
std::vector<std::tuple<ir::Expr, ir::Expr>> FindSameOuterLoops (
74
73
ir::ScheduleBlockNode* source_node, ir::ScheduleBlockNode* target_node) {
75
- std::vector<ir::Expr> src_ctrl_stmts = source_node->ControlStmts ();
76
- std::vector<ir::Expr> tgt_ctrl_stmts = target_node->ControlStmts ();
74
+ std::vector<ir::Expr> src_loops = source_node->GetLoops ();
75
+ std::vector<ir::Expr> tgt_loops = target_node->GetLoops ();
77
76
std::vector<std::tuple<ir::Expr, ir::Expr>> same_loops;
78
- int min_stmt_size = std::min (src_ctrl_stmts .size (), tgt_ctrl_stmts .size ());
77
+ int min_stmt_size = std::min (src_loops .size (), tgt_loops .size ());
79
78
for (int i = 0 ; i < min_stmt_size; ++i) {
80
- if (src_ctrl_stmts[i].As <ir::For>() && tgt_ctrl_stmts[i].As <ir::For>() &&
81
- ir::GetLoopExtent (src_ctrl_stmts[i]) ==
82
- GetLoopExtent (tgt_ctrl_stmts[i])) {
83
- same_loops.push_back (
84
- std::make_tuple (src_ctrl_stmts[i], tgt_ctrl_stmts[i]));
79
+ if (src_loops[i].As <ir::For>() && tgt_loops[i].As <ir::For>() &&
80
+ GetLoopExtent (src_loops[i]) == GetLoopExtent (tgt_loops[i])) {
81
+ same_loops.push_back (std::make_tuple (src_loops[i], tgt_loops[i]));
85
82
} else {
86
83
break ;
87
84
}
@@ -93,8 +90,10 @@ std::vector<std::tuple<ir::Expr, ir::Expr>> FindSameOuterLoops(
93
90
std::unordered_set<std::string> GetReduceLoopVarNames (ir::Expr block) {
94
91
ir::ScheduleBlockRealize* schedule_block_realize =
95
92
block.As <ir::ScheduleBlockRealize>();
93
+ CHECK_NOTNULL (schedule_block_realize);
96
94
ir::ScheduleBlock* schedule_block =
97
95
schedule_block_realize->schedule_block .As <ir::ScheduleBlock>();
96
+ CHECK_NOTNULL (schedule_block);
98
97
std::vector<ir::Expr> iter_values = schedule_block_realize->iter_values ;
99
98
std::vector<ir::Var> iter_vars = schedule_block->iter_vars ;
100
99
std::unordered_set<std::string> reduce_loop_var_names;
@@ -115,9 +114,11 @@ std::unordered_set<std::string> GetReduceLoopVarNames(ir::Expr block) {
115
114
std::unordered_set<std::string> GetReduceVarNames (ir::Expr block) {
116
115
ir::ScheduleBlockRealize* schedule_block_realize =
117
116
block.As <ir::ScheduleBlockRealize>();
117
+ CHECK_NOTNULL (schedule_block_realize);
118
118
ir::ScheduleBlock* schedule_block =
119
119
schedule_block_realize->schedule_block .As <ir::ScheduleBlock>();
120
- std::vector<ir::Var> iter_vars = schedule_block->iter_vars ;
120
+ CHECK_NOTNULL (schedule_block);
121
+ std::vector<ir::Var>& iter_vars = schedule_block->iter_vars ;
121
122
std::unordered_set<std::string> reduce_var_names;
122
123
for (int i = 0 ; i < iter_vars.size (); ++i) {
123
124
if (iter_vars[i]->is_reduce_axis ) {
@@ -162,21 +163,21 @@ NodePriority StaticShapeGroupScheduler::CalculateNodePriority(
162
163
GetReduceLoopVarNames (node->Block ());
163
164
164
165
int64_t reduce_score = 1 ;
165
- double score = 1 ;
166
- for (Expr expr : node->ControlStmts ()) {
166
+ int64_t score = 1 ;
167
+ for (Expr expr : node->GetLoops ()) {
167
168
ir::For* for_node = expr.As <ir::For>();
168
- if (for_node != nullptr ) {
169
- score * = ir::GetLoopExtent (expr);
170
- }
169
+ CHECK_NOTNULL (for_node);
170
+ int loop_extent = ir::GetLoopExtent (expr);
171
+ score *= loop_extent;
171
172
if (reduce_loop_var_names.count (for_node->loop_var ->name ) != 0 ) {
172
- reduce_score *= ir::GetLoopExtent (expr) ;
173
+ reduce_score *= loop_extent ;
173
174
}
174
175
if (for_node->is_binded ()) {
175
176
has_loop_binded = true ;
176
177
}
177
178
}
178
179
if (reduce_score > 1 ) {
179
- score *= (reduce_score * std::log2 (reduce_score) );
180
+ score = std::numeric_limits< int64_t >:: max ( );
180
181
}
181
182
182
183
VLOG (6 ) << " The priority score of node " << node->id () << " is " << score;
@@ -239,13 +240,12 @@ void StaticShapeGroupScheduler::DoLoopAlignment() {
239
240
[&](const ir::Expr* x) {
240
241
bool find_reduce_var = false ;
241
242
if (x->As <ir::Load>()) {
242
- int i = 0 ;
243
243
for (ir::Expr index : x->As <ir::Load>()->indices ) {
244
244
if (index.as_var () &&
245
245
reduce_var_names.count (index.as_var_ref ()->name ) > 0 ) {
246
246
find_reduce_var = true ;
247
+ break ;
247
248
}
248
- ++i;
249
249
}
250
250
}
251
251
return find_reduce_var;
@@ -325,7 +325,7 @@ void StaticShapeGroupScheduler::DoLoopAlignment() {
325
325
return false ;
326
326
}
327
327
328
- for (ir::Expr expr : node->ControlStmts ()) {
328
+ for (ir::Expr expr : node->GetLoops ()) {
329
329
if (expr.As <ir::For>() != nullptr &&
330
330
(expr.As <ir::For>()->for_type () == ir::ForType::GPUBlock ||
331
331
expr.As <ir::For>()->for_type () == ir::ForType::GPUThread)) {
@@ -341,7 +341,7 @@ void StaticShapeGroupScheduler::DoLoopAlignment() {
341
341
<< " with block: " << global_master->id ();
342
342
343
343
// 1. Fuse source loops
344
- ir::Expr source_loop = ir_sch_->Fuse (node->ControlStmts ());
344
+ ir::Expr source_loop = ir_sch_->Fuse (node->GetLoops ());
345
345
int total_source_extent = ir::GetLoopExtent (source_loop);
346
346
347
347
// 2. Split source loop to align with the target loops
@@ -475,11 +475,11 @@ void StaticShapeGroupScheduler::DoVerticalLoopFusion() {
475
475
ir::Expr target_loop;
476
476
bool find_target_loop = false ;
477
477
// Collect infomation of original loops
478
- std::vector<ir::Expr> original_ctrl_stmts = node->ControlStmts ();
478
+ std::vector<ir::Expr> original_loops = node->GetLoops ();
479
479
int64_t original_total_loop_extent = 1 ;
480
480
std::vector<std::pair<std::string, int >> original_loop_infos;
481
481
std::unordered_set<ir::IrNode*> original_loop_node_ptrs;
482
- for (ir::Expr stmt : original_ctrl_stmts ) {
482
+ for (ir::Expr stmt : original_loops ) {
483
483
if (stmt.As <ir::For>()) {
484
484
int extent = ir::GetLoopExtent (stmt);
485
485
original_total_loop_extent *= extent;
@@ -550,15 +550,15 @@ void StaticShapeGroupScheduler::DoVerticalLoopFusion() {
550
550
if (find_target_loop) {
551
551
ir_sch_->SimpleComputeAt (node->Block (), target_loop);
552
552
VLOG (6 ) << " after compute at: " << ir_sch_->GetModule ().GetExprs ()[0 ];
553
- std::vector<ir::Expr> new_stmts = node->ControlStmts ();
553
+ std::vector<ir::Expr> new_loops = node->GetLoops ();
554
554
for (int idx = 0 ; idx < original_loop_infos.size (); ++idx) {
555
555
if (original_loop_infos[idx].first .empty ()) {
556
556
continue ;
557
557
}
558
- if (idx < new_stmts .size ()) {
559
- CHECK (new_stmts [idx].As <ir::For>());
560
- if (new_stmts [idx].As <ir::For>()->is_serial ()) {
561
- ir_sch_->Bind (new_stmts [idx], original_loop_infos[idx].first );
558
+ if (idx < new_loops .size ()) {
559
+ CHECK (new_loops [idx].As <ir::For>());
560
+ if (new_loops [idx].As <ir::For>()->is_serial ()) {
561
+ ir_sch_->Bind (new_loops [idx], original_loop_infos[idx].first );
562
562
}
563
563
} else {
564
564
ir::Expr unit_loop = ir_sch_->AddUnitLoop (node->Block ());
0 commit comments