@@ -47,11 +47,26 @@ bool IsWarpReduce(const ScheduleConfig& config) {
47
47
return std::visit (MatchWarpReduce, config.tile_config .reduce_method );
48
48
}
49
49
50
+ bool UseReduceTile (const ScheduleConfig& config) {
51
+ const auto & raw_reduce_axis = config.base_info ->raw_reduce_axis ;
52
+ const auto raw_data_rank = config.base_info ->raw_data_rank ;
53
+ if (raw_reduce_axis.empty ()) {
54
+ return false ;
55
+ }
56
+ for (size_t i = 1 ; i < raw_reduce_axis.size (); i++) {
57
+ if (raw_reduce_axis[i] != raw_reduce_axis[i - 1 ] + 1 ) {
58
+ return false ;
59
+ }
60
+ }
61
+ return raw_reduce_axis.back () + 1 == raw_data_rank;
62
+ }
63
+
50
64
class TileFirstGeneralTactic final : public ScheduleTactic {
51
65
public:
52
66
void Init (ScheduleContext* context) override ;
53
67
54
68
void Apply (ir::IRSchedule* sch, const std::string& block_id) override ;
69
+ void ApplyReduceTile (ir::IRSchedule* sch, const std::string& block_id);
55
70
56
71
std::string TacticName () const override { return " TileFirstGeneralTactic" ; }
57
72
@@ -98,6 +113,11 @@ void TileFirstGeneralTactic::Init(ScheduleContext* context) {
98
113
99
114
void TileFirstGeneralTactic::Apply (ir::IRSchedule* sch,
100
115
const std::string& block_id) {
116
+ if (UseReduceTile (context_->config )) {
117
+ VLOG (4 ) << " Using ApplyReduceTile" ;
118
+ ApplyReduceTile (sch, block_id);
119
+ return ;
120
+ }
101
121
if (ir::IsReduceInitTensorName (block_id)) return ;
102
122
MergeReduceAxis (sch, block_id);
103
123
VLOG (6 ) << " After MergeReduceAxis on block: [" << block_id
@@ -136,6 +156,106 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
136
156
SetReduceType (sch, block_id);
137
157
}
138
158
159
+ void TileFirstGeneralTactic::ApplyReduceTile (ir::IRSchedule* sch,
160
+ const std::string& block_id) {
161
+ if (ir::IsReduceInitTensorName (block_id)) return ;
162
+
163
+ const auto sp_thread = context_->config .tile_config .warp_num * 32 /
164
+ context_->config .tile_config .tree_reduce_num ;
165
+ const auto sp_loop = context_->config .tile_config .spatial_inner_num ;
166
+ const auto rd_thread = context_->config .tile_config .tree_reduce_num ;
167
+ VLOG (4 ) << " ApplyReduceTile sp_thread=" << sp_thread;
168
+ VLOG (4 ) << " ApplyReduceTile sp_loop=" << sp_loop;
169
+ VLOG (4 ) << " ApplyReduceTile rd_thread=" << rd_thread;
170
+ VLOG (4 ) << " ApplyReduceTile vec_flatten_axis: "
171
+ << utils::Join (vec_flatten_axis_, " , " );
172
+ VLOG (4 ) << " ApplyReduceTile vec_reduce_axis: "
173
+ << utils::Join (vec_reduce_axis_, " , " );
174
+
175
+ // Merge reduce axes
176
+ MergeReduceAxis (sch, block_id);
177
+ VLOG (4 ) << " After MergeReduceAxis on block: [" << block_id
178
+ << " ], loop nest:\n "
179
+ << sch->GetModule ().GetExprs ().front ();
180
+
181
+ // Merge spatial axes
182
+ MergeFlattenAxis (sch, block_id);
183
+ VLOG (4 ) << " After MergeFlattenAxis on block: [" << block_id
184
+ << " ], loop nest:\n "
185
+ << sch->GetModule ().GetExprs ().front ();
186
+
187
+ // Split spatial axes -> [sp_block, sp_loop, sp_thread]
188
+ int current_reduce_axis = 0 ;
189
+ if (vec_flatten_axis_.size () > 0 ) {
190
+ auto loops = sch->GetLoops (block_id);
191
+ if (sp_loop > 1 && sp_thread > 1 ) {
192
+ sch->Split (loops[0 ], {-1 , sp_loop, sp_thread});
193
+ current_reduce_axis = 3 ;
194
+ } else if (sp_loop > 1 || sp_thread > 1 ) {
195
+ sch->Split (loops[0 ], {-1 , sp_loop > 1 ? sp_loop : sp_thread});
196
+ current_reduce_axis = 2 ;
197
+ } else {
198
+ current_reduce_axis = 1 ;
199
+ }
200
+ }
201
+ VLOG (4 ) << " After SplitSptial on block: [" << block_id << " ], loop nest:\n "
202
+ << sch->GetModule ().GetExprs ().front ();
203
+
204
+ // Split reduce axes -> [rd_loop, rd_thread]
205
+ if (vec_reduce_axis_.size () > 0 ) {
206
+ auto loops = sch->GetLoops (block_id);
207
+ auto reduce_loop = loops[current_reduce_axis].As <ir::For>();
208
+ sch->Split (loops[current_reduce_axis], {-1 , rd_thread});
209
+ VLOG (4 ) << " Before ReorderReduction on block: [" << block_id
210
+ << " ], loop nest:\n "
211
+ << sch->GetModule ().GetExprs ().front ();
212
+
213
+ // TODO(lshpku): the Reorder is unneeded if the later FactorizeReduction
214
+ // supports rf_axis=1.
215
+ loops = sch->GetLoops (block_id);
216
+ sch->Reorder ({loops[current_reduce_axis + 1 ], loops[current_reduce_axis]});
217
+ VLOG (4 ) << " Before FactorizeReduction on block: [" << block_id
218
+ << " ], loop nest:\n "
219
+ << sch->GetModule ().GetExprs ().front ();
220
+
221
+ if (IsReduceBlock (context_->config , block_id)) {
222
+ loops = sch->GetLoops (block_id);
223
+ sch->FactorizeReduction (loops[current_reduce_axis],
224
+ /* rf_axis = */ 0 ,
225
+ /* with_write_back_block_init = */ false );
226
+ }
227
+ }
228
+ VLOG (4 ) << " After SplitReduce on block: [" << block_id << " ], loop nest:\n "
229
+ << sch->GetModule ().GetExprs ().front ();
230
+
231
+ // Bind CUDA info
232
+ const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
233
+ std::string sp_axis_type = " threadIdx.y" ;
234
+ std::string rd_axis_type = " threadIdx.x" ;
235
+ sch->Bind (loops[0 ], " blockIdx.x" );
236
+ if (!vec_flatten_axis_.empty () && sp_thread > 1 ) {
237
+ if (vec_reduce_axis_.empty ()) {
238
+ sch->Bind (loops[current_reduce_axis - 1 ], rd_axis_type);
239
+ } else {
240
+ sch->Bind (loops[current_reduce_axis - 1 ], sp_axis_type);
241
+ }
242
+ }
243
+ if (!vec_reduce_axis_.empty () && current_reduce_axis > 0 ) {
244
+ sch->Bind (loops[current_reduce_axis], rd_axis_type);
245
+ }
246
+ };
247
+ DoBind (sch->GetLoops (block_id));
248
+ if (IsReduceBlock (context_->config , block_id) &&
249
+ sch->HasBlock (block_id + " _rf" )) {
250
+ DoBind (sch->GetLoops (block_id + " _rf" ));
251
+ }
252
+ VLOG (4 ) << " After BindCudaInfo on block: [" << block_id << " ], loop nest:\n "
253
+ << sch->GetModule ().GetExprs ().front ();
254
+
255
+ VariableTypeAssignment (sch, block_id);
256
+ SetReduceType (sch, block_id);
257
+ }
258
+
139
259
void TileFirstGeneralTactic::MergeFlattenAxis (ir::IRSchedule* sch,
140
260
const std::string& block_id) {
141
261
if (vec_flatten_axis_.size () >= 2 ) {
0 commit comments