Skip to content

Commit b3cb80e

Browse files
authored
[CINN] Remove GPU-bound For loops more cleanly (#69417)
1 parent 3b27bf0 commit b3cb80e

File tree

8 files changed

+109
-70
lines changed

8 files changed

+109
-70
lines changed

paddle/cinn/backends/codegen_gpu_dev.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ void CodeGenGpuDev::Visit(const ir::_LoweredFunc_ *op) {
148148

149149
std::vector<Expr> new_body;
150150

151+
auto axis_range_assumptions = op->PrepareAxisRangeAssumptions();
151152
auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs();
152153
auto temp_buffer_alias = GenerateBufferAliasExprs(op, op->temp_bufs);
153154
auto alis_var_exprs = op->CudaAliasVarExprs();
@@ -156,6 +157,7 @@ void CodeGenGpuDev::Visit(const ir::_LoweredFunc_ *op) {
156157

157158
#define APPEND_TO_NEW_BODY(field__) \
158159
new_body.insert(std::end(new_body), std::begin(field__), std::end(field__));
160+
APPEND_TO_NEW_BODY(axis_range_assumptions)
159161
APPEND_TO_NEW_BODY(alloca_temp_buffers)
160162
APPEND_TO_NEW_BODY(temp_buffer_alias)
161163
APPEND_TO_NEW_BODY(alis_var_exprs)

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,6 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
386386
// 4.Apply low level pass
387387
if (i != func_bodies.size() - 1) {
388388
func = optim::Optimize(func, target_, false);
389-
optim::RearrangeLoadInstruction(&(func->body));
390389
} else {
391390
func = optim::Optimize(func, common::DefaultHostTarget(), false);
392391
}

paddle/cinn/ir/lowered_func.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,35 @@ void _LoweredFunc_::PrepareAllocOutputBufferExprs() {
151151
}
152152
}
153153

154+
std::vector<Expr> _LoweredFunc_::PrepareAxisRangeAssumptions() const {
155+
std::vector<Expr> assumption_exprs;
156+
157+
const auto AssumeAxisLT = [&](std::string axis, const Expr& dim_size) {
158+
if (!dim_size.defined()) {
159+
return;
160+
}
161+
if (dim_size == common::make_const(1)) {
162+
return;
163+
}
164+
Expr expr_lt = LT::Make(Var(axis), dim_size);
165+
Expr call_lt = Call::Make(Void(),
166+
runtime::intrinsic::cuda_builtin_assume,
167+
{expr_lt},
168+
{},
169+
CallType::Intrinsic);
170+
assumption_exprs.push_back(call_lt);
171+
};
172+
173+
AssumeAxisLT("blockIdx.x", cuda_axis_info.grid_dim(0));
174+
AssumeAxisLT("blockIdx.y", cuda_axis_info.grid_dim(1));
175+
AssumeAxisLT("blockIdx.z", cuda_axis_info.grid_dim(2));
176+
AssumeAxisLT("threadIdx.x", cuda_axis_info.block_dim(0));
177+
AssumeAxisLT("threadIdx.y", cuda_axis_info.block_dim(1));
178+
AssumeAxisLT("threadIdx.z", cuda_axis_info.block_dim(2));
179+
180+
return assumption_exprs;
181+
}
182+
154183
std::vector<Expr> _LoweredFunc_::PrepareAllocTempBufferExprs() const {
155184
std::vector<Expr> alloc_temp_buffer_exprs;
156185
for (auto& temp_buf : temp_bufs) {

paddle/cinn/ir/lowered_func.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ struct _LoweredFunc_ : public IrNode {
208208

209209
static const IrNodeTy _node_type_ = IrNodeTy::LoweredFunc;
210210

211+
//! Prepare the assumptions that a gpu axis should be less than its
212+
//! corresponding dim size, e.g. threadIdx.x < blockDim.x.
213+
std::vector<Expr> PrepareAxisRangeAssumptions() const;
211214
std::vector<Expr> PrepareCreateTempBufferExprs() const;
212215
//! Prepare the expressions for `alloc_tmp_buffer_exprs`.
213216
std::vector<Expr> PrepareAllocTempBufferExprs() const;

paddle/cinn/optim/optimize.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
7373
#ifdef CINN_WITH_CUDA
7474
ir::SetCudaAxisInfo(copied);
7575
if (remove_gpu_for_loops) {
76-
RemoveGpuForloopsAxis(copied);
76+
RemoveGpuForLoops(copied);
7777
}
7878
CudaSyncThreadsDropIfThenElse(copied);
7979
// CudaTransBufferWithDynamicShape(&copied);
@@ -83,7 +83,7 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
8383
#ifdef CINN_WITH_HIP
8484
ir::SetCudaAxisInfo(copied);
8585
if (remove_gpu_for_loops) {
86-
RemoveGpuForloopsAxis(copied);
86+
RemoveGpuForLoops(copied);
8787
}
8888
CudaSyncThreadsDropIfThenElse(copied);
8989
// CudaTransBufferWithDynamicShape(&copied);

paddle/cinn/optim/transform_gpu_forloop.cc

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "paddle/cinn/backends/cuda_util.h"
2424
#include "paddle/cinn/common/cas.h"
25+
#include "paddle/cinn/common/integer_set.h"
2526
#include "paddle/cinn/common/ir_util.h"
2627
#include "paddle/cinn/ir/ir.h"
2728
#include "paddle/cinn/ir/ir_mutator.h"
@@ -43,27 +44,17 @@ PD_DECLARE_bool(cinn_longlong2int_for_integer);
4344
namespace cinn {
4445
namespace optim {
4546

46-
/**
47-
* 1. Determine the grid and block dimensions.
48-
* It takes the domains like `[0, 20]` or `[0, min(20, M/2)]`, the domain should
49-
* have a integer right bound.
50-
*
51-
* 2. Replace the grid/thread iterators with something like `threadIdx.x`,
52-
* `threadIdx.y`.
53-
*
54-
* 3. Remove the forloops owning the gpu axis.
55-
* 1. if the extent is an IntImm, just remove this forloop.
56-
* 2. if the extent is a Min, replace the forloop with an IfThenElse, with
57-
* forloop's condition, new check will add (if the min of forloop is not zero).
58-
*
59-
* @param expr The expression to mutate.
60-
*/
61-
void RemoveGpuForloopsAxis(ir::LoweredFunc fn) {
47+
void RemoveGpuForLoops(ir::LoweredFunc fn) {
6248
struct Mutator : public ir::IRMutator<Expr *> {
6349
using ir::IRMutator<>::Visit;
64-
void operator()(ir::LoweredFunc fn) { Visit(fn.As<ir::_LoweredFunc_>()); }
50+
void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
51+
52+
explicit Mutator(const ir::CudaAxisInfo &cuda_axis_info)
53+
: cuda_axis_info_(cuda_axis_info) {}
6554

6655
private:
56+
ir::CudaAxisInfo cuda_axis_info_;
57+
6758
void Visit(const ir::For *op, Expr *expr) override {
6859
switch (op->for_type()) {
6960
case ir::ForType::GPUBlock:
@@ -90,56 +81,64 @@ void RemoveGpuForloopsAxis(ir::LoweredFunc fn) {
9081
}
9182

9283
bool NeedToReplaceForloopWithIfThenElse(const ir::For *n) const {
84+
// If the loop doesn't start from 0.
85+
if (n->min != cinn::common::make_const(0)) {
86+
return true;
87+
}
88+
89+
// Get dim_size from the function's cuda_axis_info as pre-condition.
90+
ir::Expr dim_size;
91+
switch (n->bind_info().for_type) {
92+
case ir::ForType::GPUThread:
93+
dim_size = cuda_axis_info_.block_dim(n->bind_info().offset);
94+
break;
95+
case ir::ForType::GPUBlock:
96+
dim_size = cuda_axis_info_.grid_dim(n->bind_info().offset);
97+
break;
98+
}
99+
if (!dim_size.defined()) {
100+
return true;
101+
}
102+
103+
// If we can prove the loop's extent >= dim_size, then it's safe not
104+
// to add the IfThenElse guard.
105+
common::cas_intervals_t var_intervals =
106+
common::CollectVarIntervalsOfExprs({n->extent, dim_size});
107+
common::SymbolicExprAnalyzer analyzer{var_intervals};
108+
std::optional<bool> proved_ge = analyzer.ProveGE(n->extent, dim_size);
109+
if (proved_ge.value_or(false)) {
110+
return false;
111+
}
93112
return true;
94113
}
95114

96115
void ReplaceForloopWithIfThenElse(Expr *expr) {
97116
auto *for_n = expr->As<ir::For>();
98-
auto *poly_for_n = expr->As<ir::PolyFor>();
99-
PADDLE_ENFORCE_EQ(for_n || poly_for_n,
100-
true,
101-
::common::errors::InvalidArgument(
102-
"PolyFor is not exist, please check."));
103117

104118
Expr condition;
105-
106-
auto condition_append = [&](Expr new_cond) {
119+
const auto AppendCondition = [&](Expr new_cond) {
107120
if (condition.defined()) {
108121
condition = ir::And::Make(condition, new_cond);
109122
} else {
110123
condition = new_cond;
111124
}
112125
};
113126

114-
if (for_n) {
115-
// for(i, 2, 100);
116-
// ^
117-
if (for_n->min != cinn::common::make_const(0)) {
118-
condition_append(ir::GE::Make(for_n->loop_var, for_n->min));
119-
}
120-
121-
// for(i, 2, min(M/2, 20)
122-
// ^
123-
condition_append(ir::LT::Make(for_n->loop_var, for_n->extent));
124-
} else {
125-
if (poly_for_n->init != cinn::common::make_const(0)) {
126-
condition_append(
127-
ir::GE::Make(poly_for_n->iterator, poly_for_n->init));
128-
}
129-
130-
condition_append(poly_for_n->condition);
127+
// for(i, 2, 100);
128+
// ^
129+
if (for_n->min != cinn::common::make_const(0)) {
130+
AppendCondition(ir::GE::Make(for_n->loop_var, for_n->min));
131131
}
132+
// for(i, 2, min(M/2, 20))
133+
// ^
134+
AppendCondition(ir::LT::Make(for_n->loop_var, for_n->extent));
132135

133136
PADDLE_ENFORCE_EQ(condition.defined(),
134137
true,
135138
::common::errors::InvalidArgument(
136139
"Condition is not defined, please check."));
137140

138-
VLOG(3) << "GPU replacing\n" << *expr;
139-
VLOG(3) << "\nto\n";
140-
auto if_n = ir::IfThenElse::Make(condition, for_n->body);
141-
VLOG(3) << if_n;
142-
*expr = if_n;
141+
*expr = ir::IfThenElse::Make(condition, for_n->body);
143142
}
144143

145144
void Visit(const ir::PolyFor *op, Expr *expr) override {
@@ -163,8 +162,8 @@ void RemoveGpuForloopsAxis(ir::LoweredFunc fn) {
163162
}
164163
};
165164

166-
Mutator mutator;
167-
mutator(fn);
165+
Mutator mutator(fn->cuda_axis_info);
166+
mutator(&fn->body);
168167
}
169168

170169
/**

paddle/cinn/optim/transform_gpu_forloop.h

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,29 +33,34 @@ void OptimizeExprGPU(Expr* expr);
3333
*/
3434

3535
/**
36-
* Remove the forloops of block and thread axis, add the kernel launch thread
37-
* dimension information to the outermost LoweredFunc.
36+
* Remove the GPU block/thread-bound For loops, add IfThenElse guards if needed.
3837
*
39-
* For example, input the code:
40-
* \code
41-
* // Note here, the outermost expression should be a LoweredFunc
42-
* _LoweredFunc_:
43-
* for (blockIdx.x, 0, 10)
44-
* for (threadIdx.x, 0, 20)
45-
* A(blockIdx.x, threadIdx.x)
46-
* \endcode
38+
* It's usually safe to remove bound loops, because when launching the kernel,
39+
* we are expected to choose dim sizes that match the extents of these loops.
40+
* However, there are cases where we cannot simply remove a loop, but need to
41+
* add an IfThenElse as guard:
42+
* 1) if the loop doesn't start from 0.
43+
* 2) if we cannot prove that the loop's extent is always equal to or greater
44+
* than the corresponding dim size.
4745
*
48-
* will be modified to
49-
* \code
50-
* _LoweredFunc_<blockDim:10, threadDim:20>:
51-
* A(blockIdx.x, threadIdx.x)
52-
* \endcode
46+
* Example 1:
47+
* # assume blockDim.x == 256
48+
* thread_bind[threadIdx.x] for (k, 0, 256):
49+
* ScheduleBlock(A)
50+
* =>
51+
* ScheduleBlock(A)
5352
*
54-
* \note For that the dimensions of each threadIdx or blockIdx should be
55-
* constant, so this only takes For nodes, not \note PolyFor nodes is allowed to
56-
* be GPU related.
53+
* Example 2:
54+
* # assume gridDim.x == 8
55+
* thread_bind[blockIdx.x] for (k, 2, min(S0, 8)):
56+
* ScheduleBlock(A)
57+
* =>
58+
* if (blockIdx.x >= 2 && blockIdx.x < min(S0, 8)):
59+
* ScheduleBlock(A)
60+
*
61+
* @param fn The LoweredFunc to process.
5762
*/
58-
void RemoveGpuForloopsAxis(ir::LoweredFunc fn);
63+
void RemoveGpuForLoops(ir::LoweredFunc fn);
5964

6065
/**
6166
* Add __syncthreads() to shared memory producer.

paddle/cinn/runtime/intrinsic.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ static const char* debug_log_repr = "cinn_print_debug_string";
129129

130130
static const char* cuda_sync_threads = "__syncthreads";
131131

132+
static const char* cuda_builtin_assume = "__builtin_assume";
133+
132134
static const char* parallel_launch = "cinn_backend_parallel_launch";
133135

134136
} // namespace intrinsic

0 commit comments

Comments
 (0)