Commit 72e8d4e

[CINN] New tiling method optimized for warp-level continuous read (PaddlePaddle#64240)
* add new tile apply function
* replace index of local buffer
1 parent 252b746 commit 72e8d4e

File tree

2 files changed: +151 -0 lines changed

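In brief, and as a hedged reading of the diff below rather than text from the commit itself: UseReduceTile selects the new path only when the raw reduce axes are consecutive and end at the innermost data dimension. ApplyReduceTile then tiles the fused spatial axes into [sp_block, sp_loop, sp_thread] and the fused reduce axis into [rd_thread, rd_loop], with bindings roughly like (loop names here are illustrative, not identifiers from the commit):

    for (blockIdx.x : spatial_size / (sp_loop * sp_thread))
      for (i : sp_loop)                  // serial, spatial_inner_num
        for (threadIdx.y : sp_thread)    // warp_num * 32 / tree_reduce_num
          for (threadIdx.x : rd_thread)  // tree_reduce_num
            for (j : rd_loop)            // serial; reduce element = j * rd_thread + threadIdx.x

Because the reduced axes are the innermost ones in memory, consecutive threadIdx.x lanes load consecutive addresses, which is the warp-level continuous (coalesced) read named in the title.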

paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc

Lines changed: 120 additions & 0 deletions
@@ -47,11 +47,26 @@ bool IsWarpReduce(const ScheduleConfig& config) {
   return std::visit(MatchWarpReduce, config.tile_config.reduce_method);
 }
 
+bool UseReduceTile(const ScheduleConfig& config) {
+  const auto& raw_reduce_axis = config.base_info->raw_reduce_axis;
+  const auto raw_data_rank = config.base_info->raw_data_rank;
+  if (raw_reduce_axis.empty()) {
+    return false;
+  }
+  for (size_t i = 1; i < raw_reduce_axis.size(); i++) {
+    if (raw_reduce_axis[i] != raw_reduce_axis[i - 1] + 1) {
+      return false;
+    }
+  }
+  return raw_reduce_axis.back() + 1 == raw_data_rank;
+}
+
 class TileFirstGeneralTactic final : public ScheduleTactic {
  public:
   void Init(ScheduleContext* context) override;
 
   void Apply(ir::IRSchedule* sch, const std::string& block_id) override;
+  void ApplyReduceTile(ir::IRSchedule* sch, const std::string& block_id);
 
   std::string TacticName() const override { return "TileFirstGeneralTactic"; }
 
@@ -98,6 +113,11 @@ void TileFirstGeneralTactic::Init(ScheduleContext* context) {
 
 void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
                                    const std::string& block_id) {
+  if (UseReduceTile(context_->config)) {
+    VLOG(4) << "Using ApplyReduceTile";
+    ApplyReduceTile(sch, block_id);
+    return;
+  }
   if (ir::IsReduceInitTensorName(block_id)) return;
   MergeReduceAxis(sch, block_id);
   VLOG(6) << "After MergeReduceAxis on block: [" << block_id
@@ -136,6 +156,106 @@ void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch,
   SetReduceType(sch, block_id);
 }
 
+void TileFirstGeneralTactic::ApplyReduceTile(ir::IRSchedule* sch,
+                                             const std::string& block_id) {
+  if (ir::IsReduceInitTensorName(block_id)) return;
+
+  const auto sp_thread = context_->config.tile_config.warp_num * 32 /
+                         context_->config.tile_config.tree_reduce_num;
+  const auto sp_loop = context_->config.tile_config.spatial_inner_num;
+  const auto rd_thread = context_->config.tile_config.tree_reduce_num;
+  VLOG(4) << "ApplyReduceTile sp_thread=" << sp_thread;
+  VLOG(4) << "ApplyReduceTile sp_loop=" << sp_loop;
+  VLOG(4) << "ApplyReduceTile rd_thread=" << rd_thread;
+  VLOG(4) << "ApplyReduceTile vec_flatten_axis: "
+          << utils::Join(vec_flatten_axis_, ", ");
+  VLOG(4) << "ApplyReduceTile vec_reduce_axis: "
+          << utils::Join(vec_reduce_axis_, ", ");
+
+  // Merge reduce axes
+  MergeReduceAxis(sch, block_id);
+  VLOG(4) << "After MergeReduceAxis on block: [" << block_id
+          << "], loop nest:\n"
+          << sch->GetModule().GetExprs().front();
+
+  // Merge spatial axes
+  MergeFlattenAxis(sch, block_id);
+  VLOG(4) << "After MergeFlattenAxis on block: [" << block_id
+          << "], loop nest:\n"
+          << sch->GetModule().GetExprs().front();
+
+  // Split spatial axes -> [sp_block, sp_loop, sp_thread]
+  int current_reduce_axis = 0;
+  if (vec_flatten_axis_.size() > 0) {
+    auto loops = sch->GetLoops(block_id);
+    if (sp_loop > 1 && sp_thread > 1) {
+      sch->Split(loops[0], {-1, sp_loop, sp_thread});
+      current_reduce_axis = 3;
+    } else if (sp_loop > 1 || sp_thread > 1) {
+      sch->Split(loops[0], {-1, sp_loop > 1 ? sp_loop : sp_thread});
+      current_reduce_axis = 2;
+    } else {
+      current_reduce_axis = 1;
+    }
+  }
+  VLOG(4) << "After SplitSpatial on block: [" << block_id << "], loop nest:\n"
+          << sch->GetModule().GetExprs().front();
+
+  // Split reduce axes -> [rd_loop, rd_thread]
+  if (vec_reduce_axis_.size() > 0) {
+    auto loops = sch->GetLoops(block_id);
+    auto reduce_loop = loops[current_reduce_axis].As<ir::For>();
+    sch->Split(loops[current_reduce_axis], {-1, rd_thread});
+    VLOG(4) << "Before ReorderReduction on block: [" << block_id
+            << "], loop nest:\n"
+            << sch->GetModule().GetExprs().front();
+
+    // TODO(lshpku): the Reorder is unneeded if the later FactorizeReduction
+    // supports rf_axis=1.
+    loops = sch->GetLoops(block_id);
+    sch->Reorder({loops[current_reduce_axis + 1], loops[current_reduce_axis]});
+    VLOG(4) << "Before FactorizeReduction on block: [" << block_id
+            << "], loop nest:\n"
+            << sch->GetModule().GetExprs().front();
+
+    if (IsReduceBlock(context_->config, block_id)) {
+      loops = sch->GetLoops(block_id);
+      sch->FactorizeReduction(loops[current_reduce_axis],
+                              /* rf_axis = */ 0,
+                              /* with_write_back_block_init = */ false);
+    }
+  }
+  VLOG(4) << "After SplitReduce on block: [" << block_id << "], loop nest:\n"
+          << sch->GetModule().GetExprs().front();
+
+  // Bind CUDA info
+  const auto DoBind = [&](const std::vector<ir::Expr>& loops) {
+    std::string sp_axis_type = "threadIdx.y";
+    std::string rd_axis_type = "threadIdx.x";
+    sch->Bind(loops[0], "blockIdx.x");
+    if (!vec_flatten_axis_.empty() && sp_thread > 1) {
+      if (vec_reduce_axis_.empty()) {
+        sch->Bind(loops[current_reduce_axis - 1], rd_axis_type);
+      } else {
+        sch->Bind(loops[current_reduce_axis - 1], sp_axis_type);
+      }
+    }
+    if (!vec_reduce_axis_.empty() && current_reduce_axis > 0) {
+      sch->Bind(loops[current_reduce_axis], rd_axis_type);
+    }
+  };
+  DoBind(sch->GetLoops(block_id));
+  if (IsReduceBlock(context_->config, block_id) &&
+      sch->HasBlock(block_id + "_rf")) {
+    DoBind(sch->GetLoops(block_id + "_rf"));
+  }
+  VLOG(4) << "After BindCudaInfo on block: [" << block_id << "], loop nest:\n"
+          << sch->GetModule().GetExprs().front();
+
+  VariableTypeAssignment(sch, block_id);
+  SetReduceType(sch, block_id);
+}
+
 void TileFirstGeneralTactic::MergeFlattenAxis(ir::IRSchedule* sch,
                                               const std::string& block_id) {
   if (vec_flatten_axis_.size() >= 2) {
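For concreteness, a minimal standalone C++ sketch of the two decisions above: the UseReduceTile contiguity check and the thread-shape arithmetic at the top of ApplyReduceTile. UseReduceTileLike is a hypothetical name, and warp_num = 8, tree_reduce_num = 64 are assumed example values, not numbers from the commit:

#include <cstdio>
#include <vector>

// Mirrors UseReduceTile: the reduce axes must be consecutive and must end at
// the last (fastest-varying) dimension of the data.
bool UseReduceTileLike(const std::vector<int>& reduce_axis, int data_rank) {
  if (reduce_axis.empty()) return false;
  for (size_t i = 1; i < reduce_axis.size(); i++) {
    if (reduce_axis[i] != reduce_axis[i - 1] + 1) return false;
  }
  return reduce_axis.back() + 1 == data_rank;
}

int main() {
  // Reduce over axes {2, 3} of a rank-4 tensor: contiguous and innermost.
  std::printf("%d\n", UseReduceTileLike({2, 3}, 4));  // prints 1
  // Reduce over axes {1, 2}: the innermost axis is spatial, path not taken.
  std::printf("%d\n", UseReduceTileLike({1, 2}, 4));  // prints 0

  // Thread shape as computed in ApplyReduceTile, with the assumed config.
  const int warp_num = 8, tree_reduce_num = 64;
  const int sp_thread = warp_num * 32 / tree_reduce_num;  // = 4
  std::printf("blockDim = (%d, %d)\n", tree_reduce_num, sp_thread);  // (64, 4)
  return 0;
}

With blockDim.x a multiple of 32, each warp's 32 lanes all sit inside threadIdx.x, so an entire warp cooperates on the tree reduction over the contiguous axis.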

paddle/cinn/optim/resize_buffer.cc

Lines changed: 31 additions & 0 deletions
@@ -249,6 +249,7 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
     ir::Store* store = expr->As<ir::Store>();
     ir::Tensor tensor = store->tensor.as_tensor_ref();
     ResizeTensor(&tensor);
+    ReplaceTensorIndices<ir::Store>(store);
     ir::IRMutator<>::Visit(op, expr);
   }
 
@@ -277,6 +278,7 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
     for (int i = 0; i < cnt; i++) {
       load->indices.erase(load->indices.begin());
     }
+    ReplaceTensorIndices<ir::Load>(load);
     ir::IRMutator<>::Visit(op, expr);
   }
 
@@ -304,6 +306,35 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> {
     }
   }
 
+  template <typename T>
+  void ReplaceTensorIndices(T* op) {
+    ir::Tensor tensor = op->tensor.as_tensor_ref();
+    ir::Buffer buffer = tensor->buffer;
+    if (!buffer.defined()) return;
+    if (buffer->memory_type != ir::MemoryType::GPULocal) return;
+
+    VLOG(4) << "replacing index of tensor: " << tensor->name;
+    ir::Expr index_expr = op->index();
+    std::unordered_map<std::string, ir::Expr> var_name_to_expr;
+    ir::ir_utils::CollectIRNodes(index_expr, [&](const ir::Expr* x) {
+      const ir::_Var_* var = x->as_var();
+      if (var) {
+        var_name_to_expr[var->name] = var->Copy();
+      }
+      return false;
+    });
+    if (var_name_to_expr.size() != 1) {
+      return;
+    }
+
+    ir::Expr single_var = var_name_to_expr.begin()->second;
+    VLOG(4) << "found single var: " << single_var;
+    for (size_t i = 0; i + 1 < op->indices.size(); i++) {
+      op->indices[i] = ir::Expr(0);
+    }
+    op->indices.back() = single_var;
+  }
+
  private:
   const std::unordered_map<std::string, std::vector<ir::Expr>>&
       buffer_name_to_shape_;
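The resize_buffer.cc change is the companion piece: after local buffers are resized, any GPU-local (register) buffer access whose flattened index expression contains exactly one variable has its indices rewritten to [0, ..., 0, var]. Below is a toy standalone re-statement of that rewrite; ReplaceIndicesLike, tmp, and the indices (i / 4, i % 4) are invented for illustration and do not come from the commit:

#include <cstdio>
#include <string>
#include <vector>

// Mirrors the effect of ReplaceTensorIndices on a toy index list: once a
// single surviving variable is found in the flattened index, zero every
// leading index and address the last dimension by that variable directly.
void ReplaceIndicesLike(std::vector<std::string>* indices,
                        const std::string& single_var) {
  for (size_t i = 0; i + 1 < indices->size(); i++) (*indices)[i] = "0";
  indices->back() = single_var;
}

int main() {
  // A local temporary accessed as tmp[i / 4][i % 4]; the flattened index
  // (i / 4) * 4 + (i % 4) mentions only the variable i, so the rewrite fires.
  std::vector<std::string> indices = {"i / 4", "i % 4"};
  ReplaceIndicesLike(&indices, "i");
  std::printf("tmp[%s][%s]\n", indices[0].c_str(), indices[1].c_str());
  // prints: tmp[0][i]
  return 0;
}

The apparent effect is that a multi-dimensional local temporary collapses into a simple per-thread array indexed by a single loop variable, matching the thread-local layout the new tiling produces.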
