Skip to content

Commit 1f98e30

Browse files
authored
Refine bind gpu axis (PaddlePaddle#124)
2 parents cdb90c1 + 4817f5d commit 1f98e30

12 files changed

+204
-45
lines changed

cinn/backends/codegen_c.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,10 +524,12 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
524524
<< "the count of allocation and deallocaton expressions is not match";
525525

526526
std::vector<Expr> new_body;
527+
528+
auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs();
527529
#define APPEND_TO_NEW_BODY(field__) new_body.insert(std::end(new_body), std::begin(op->field__), std::end(op->field__));
528530
APPEND_TO_NEW_BODY(argument_prepare_exprs)
529531
APPEND_TO_NEW_BODY(alloc_output_buffer_exprs)
530-
APPEND_TO_NEW_BODY(alloc_tmp_buffer_exprs)
532+
new_body.insert(std::end(new_body), std::begin(alloca_temp_buffers), std::end(alloca_temp_buffers));
531533
APPEND_TO_NEW_BODY(buffer_data_cast_exprs)
532534
new_body.push_back(op->body);
533535
APPEND_TO_NEW_BODY(dealloc_output_buffer_exprs)

cinn/backends/codegen_cuda_dev.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ void CodeGenCUDA_Dev::Visit(const ir::_LoweredFunc_ *op) {
7979

8080
DoIndent();
8181

82-
Expr temp_buffer_alloc = ir::Block::Make(op->alloc_tmp_buffer_exprs);
82+
Expr temp_buffer_alloc = ir::Block::Make(op->PrepareAllocTempBufferExprs());
8383
Expr func_body = op->body;
8484
Expr temp_buffer_alias = ir::Block::Make(GenerateBufferAliasExprs(op, op->temp_bufs));
8585

cinn/backends/codegen_cuda_dev_test.cc

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,9 @@ TEST(CodeGenCUDA, basic) {
5858

5959
auto C = Compute(
6060
{M, N}, [&](Var i, Var j) { return A(i, j) * B(i, j); }, "C");
61-
C->WithBuffer();
6261

63-
C->stage()->GpuBlocks({C->stage()->axis(0)});
64-
C->stage()->GpuThreads({C->stage()->axis(1)});
62+
C->stage()->Bind(0, "blockIdx.x");
63+
C->stage()->Bind(1, "threadIdx.x");
6564

6665
CodeGenCUDA_Dev codegen(target);
6766

@@ -883,5 +882,71 @@ TEST(Conv, optimize) {
883882
LOG(INFO) << Lower("conv", {A, W, BL}, {}, {AA, WW, AL, WL, B});
884883
}
885884

885+
TEST(ElementwiseAdd, cache_read) {
886+
Expr M(100);
887+
Expr N(200);
888+
889+
Placeholder<float> A("A", {M, N});
890+
Placeholder<float> B("B", {M, N});
891+
892+
auto C = Compute(
893+
{M, N}, [&](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C");
894+
C->stage()->Split(1, 10);
895+
896+
auto AL = A->stage()->CacheRead("local", {C});
897+
AL->stage()->Split(1, 10);
898+
899+
AL->stage()->ComputeAt(C->stage(), 1, poly::Stage::ComputeAtKind::kComputeAtUnk, A->name);
900+
C->stage()->Bind(0, "threadIdx.x");
901+
C->stage()->Bind(1, "blockIdx.x");
902+
903+
Target target;
904+
CodeGenCUDA_Dev codegen(target);
905+
906+
auto fn = Lower("fn", {A, B, C}, {}, {AL});
907+
908+
Module::Builder builder("module", target);
909+
builder.AddFunction(fn);
910+
911+
auto source_code = codegen.Compile(builder.Build());
912+
LOG(INFO) << "source:\n" << source_code;
913+
914+
std::string source_target = R"ROC(
915+
extern "C" {
916+
917+
#ifdef __CUDACC_RTC__
918+
typedef int int32_t;
919+
typedef char int8_t;
920+
#endif
921+
922+
923+
924+
__global__
925+
void fn_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
926+
{
927+
float _A_read_cache_3 [ 1 * 10 ];
928+
float* A_read_cache_3 = _A_read_cache_3;
929+
{
930+
if (((((threadIdx.x >= 0) && (threadIdx.x <= 99)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) {
931+
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
932+
A_read_cache_3[j_inner] = A[((10 * blockIdx.x) + ((200 * threadIdx.x) + j_inner))];
933+
};
934+
};
935+
for (int32_t i = 0; i < 10; i += 1) {
936+
C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[((10 * blockIdx.x) + ((10 * threadIdx.x) + i))] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]);
937+
};
938+
};
939+
}
940+
941+
}
942+
)ROC";
943+
// ASSERT_EQ(utils::Trim(source_target), source);
944+
945+
backends::NVRTC_Compiler compiler;
946+
947+
auto ptx = compiler(source_code);
948+
CHECK(!ptx.empty()) << "Compile error!";
949+
}
950+
886951
} // namespace backends
887952
} // namespace cinn

cinn/common/ir_util.cc

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,5 +337,49 @@ Expr cast(Expr e, Type type) {
337337
return ir::Cast::Make(type, e);
338338
}
339339

340+
std::vector<std::string> GatherItersToTensorProducer(const std::string &target_tensor_name, Expr *expr) {
341+
struct Visitor : public ir::IRMutator<> {
342+
std::vector<std::string> iters;
343+
const std::string &target_tensor_name;
344+
345+
Visitor(const std::string &target_tensor_name) : target_tensor_name(target_tensor_name) {}
346+
347+
std::vector<std::string> operator()(Expr *expr) {
348+
ir::IRMutator<>::Visit(expr, expr);
349+
return iters;
350+
}
351+
352+
void Visit(const ir::Store *op, Expr *expr) {
353+
if (op->tensor.as_tensor()->name == target_tensor_name) {
354+
CHECK(iters.empty());
355+
for (auto &e : for_stack) {
356+
auto *for_n = e->As<ir::For>();
357+
auto *polyfor_n = e->As<ir::PolyFor>();
358+
if (for_n) {
359+
iters.push_back(for_n->loop_var->name);
360+
} else {
361+
iters.push_back(polyfor_n->iterator->name);
362+
}
363+
}
364+
}
365+
}
366+
367+
void Visit(const ir::For *op, Expr *expr) {
368+
for_stack.push_back(expr);
369+
ir::IRMutator<>::Visit(op, expr);
370+
for_stack.pop_back();
371+
}
372+
void Visit(const ir::PolyFor *op, Expr *expr) {
373+
for_stack.push_back(expr);
374+
ir::IRMutator<>::Visit(op, expr);
375+
for_stack.pop_back();
376+
}
377+
378+
std::vector<Expr *> for_stack;
379+
};
380+
381+
return Visitor(target_tensor_name)(expr);
382+
}
383+
340384
} // namespace common
341385
} // namespace cinn

cinn/common/ir_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ void UnifyAllTensorsInExpr(Expr *expr);
7676
*/
7777
void UnifyAllBuffersInExpr(Expr *Expr);
7878

79+
std::vector<std::string> GatherItersToTensorProducer(const std::string &target_tensor_name, Expr *expr);
80+
7981
bool is_zero(Expr v);
8082

8183
bool MathEqual(const Expr &a, const Expr &b);

cinn/ir/lowered_func.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,12 @@ void _LoweredFunc_::PrepareAllocOutputBufferExprs() {
7171
}
7272
}
7373

74-
void _LoweredFunc_::PrepareAllocTempBufferExprs() {
74+
std::vector<Expr> _LoweredFunc_::PrepareAllocTempBufferExprs() const {
75+
std::vector<Expr> alloc_output_buffer_exprs;
7576
for (auto& temp_buf : temp_bufs) {
76-
alloc_tmp_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr()));
77+
alloc_output_buffer_exprs.push_back(Alloc::Make(temp_buf, temp_buf->type(), temp_buf->shape, Expr(), Expr()));
7778
}
79+
return alloc_output_buffer_exprs;
7880
}
7981

8082
void _LoweredFunc_::PrepareDeallocOutputBufferExprs() {

cinn/ir/lowered_func.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
105105
std::vector<Expr> dealloc_output_buffer_exprs;
106106
// @}
107107

108-
std::vector<Expr> alloc_tmp_buffer_exprs;
109108
//! something like: float* A_data = (float*)(A->host_memory);
110109
std::vector<Expr> buffer_data_cast_exprs;
111110

@@ -123,12 +122,13 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
123122

124123
static const IrNodeTy _node_type_ = IrNodeTy::_LoweredFunc_;
125124

125+
//! Create and return the expressions that allocate the temporary buffers (`temp_bufs`).
126+
std::vector<Expr> PrepareAllocTempBufferExprs() const;
127+
126128
private:
127129
void CheckValid() const;
128130
//! Prepare the expressions for `alloc_output_buffer_exprs`.
129131
void PrepareAllocOutputBufferExprs();
130-
//! Prepare the expressions for `alloc_tmp_buffer_exprs`.
131-
void PrepareAllocTempBufferExprs();
132132
//! Prepare the expressions for `dealloc_output_buffer_exprs`.
133133
void PrepareDeallocOutputBufferExprs();
134134
//! Insert the allocation expr for temporary variables.

cinn/lang/lower_impl.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,15 @@ Expr LowerGroup(const poly::ScheduleGroup& group, const std::map<std::string, Ex
105105
{
106106
optim::forloop_infos_t forloop_infos;
107107
for (auto* stage : stages) {
108-
forloop_infos[stage->id()] = stage->forloop_infos();
108+
// Transform the forloop infos keyed by loop level into ones keyed by iterator name.
109+
auto iters = common::GatherItersToTensorProducer(stage->id(), &e);
110+
std::map<std::string, poly::StageForloopInfo> for_infos;
111+
for (auto& item : stage->forloop_infos()) {
112+
CHECK_LT(item.first, iters.size());
113+
for_infos[iters[item.first]] = item.second;
114+
}
115+
116+
forloop_infos[stage->id()] = for_infos;
109117
}
110118
optim::TransformGpuForloop(forloop_infos, &e);
111119
}
@@ -772,13 +780,6 @@ void UpdateComputeAtBufferShape(Expr* expr) {
772780
process_buffer(Reference(&buf).operator->(), *compute_at_it->second);
773781
}
774782
}
775-
776-
for (auto& expr : node->alloc_tmp_buffer_exprs) {
777-
auto compute_at_it = buffer_to_compute_at_info.find(expr.As<ir::Alloc>()->destination.as_buffer()->name);
778-
if (compute_at_it != buffer_to_compute_at_info.end()) {
779-
process_alloca(Reference(&expr).As<ir::Alloc>(), *compute_at_it->second);
780-
}
781-
}
782783
}
783784
}
784785

cinn/optim/ir_copy.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
218218

219219
std::vector<Expr> alloc_output_buffer_exprs;
220220
std::vector<Expr> dealloc_output_buffer_exprs;
221-
std::vector<Expr> alloc_tmp_buffer_exprs;
222221
std::vector<Expr> buffer_data_cast_exprs;
223222
std::vector<Expr> argument_prepare_exprs;
224223

@@ -230,7 +229,6 @@ struct IRCopyVisitor : public ir::IRVisitorBase<Expr> {
230229

231230
COPY_ADD_FIELD(alloc_output_buffer_exprs);
232231
COPY_ADD_FIELD(dealloc_output_buffer_exprs);
233-
COPY_ADD_FIELD(alloc_tmp_buffer_exprs);
234232
COPY_ADD_FIELD(buffer_data_cast_exprs);
235233
COPY_ADD_FIELD(argument_prepare_exprs);
236234

cinn/poly/compute_at_transform.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,11 @@ ComputeAtTransform::ComputeAtTransform(
151151
ptransform_(ptransform),
152152
ctransform_(ctransform),
153153
level_(level) {
154-
LOG(INFO) << "pdomain: " << pdomain;
155-
LOG(INFO) << "ptransform: " << ptransform;
156-
LOG(INFO) << "cdomain: " << cdomain;
157-
LOG(INFO) << "ctransform: " << ctransform;
154+
VLOG(2) << "pdomain: " << pdomain;
155+
VLOG(2) << "ptransform: " << ptransform;
156+
VLOG(2) << "cdomain: " << cdomain;
157+
VLOG(2) << "ctransform: " << ctransform;
158+
VLOG(2) << "access: " << access;
158159

159160
adjusted_ctransform_ = isl::manage(AddParamsTo(ctransform_.copy()));
160161
adjusted_cdomain_ = isl::manage(AddParamsTo(cdomain_.copy()));

cinn/poly/stage.cc

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -172,14 +172,27 @@ void Stage::ComputeAtSchedule(Stage *other, int level, ComputeAtKind kind) {
172172
}
173173
}
174174

175-
void Stage::ComputeAt(Stage *other, int level, Stage::ComputeAtKind kind) {
176-
auto accesses = GatherAccesses(other, tensor_->name);
177-
if (accesses.empty()) return;
178-
auto access = accesses[0];
179-
for (int i = 1; i < accesses.size(); i++) {
180-
access = isl::manage(isl_map_union(access.release(), accesses[i].copy()));
175+
void Stage::ComputeAt(Stage *other, int level, Stage::ComputeAtKind kind, const std::string &cached_tensor_name) {
176+
isl::map access;
177+
isl_map *access_raw{};
178+
// The cache_read schedule replaces the producer tensor with its cache in the consumer, so rename the
179+
// output tuple of the access relation to the cache tensor's name.
180+
if (cached_tensor_name.empty())
181+
access_raw = GatherAccesses(other, tensor_->name);
182+
else
183+
access_raw = GatherAccesses(other, cached_tensor_name);
184+
185+
if (!access_raw) {
186+
LOG(ERROR) << "ComputeAt: " << other->tensor_->name << " has no access to " << tensor_->name << ", skipped it";
187+
return;
181188
}
182189

190+
if (!cached_tensor_name.empty()) {
191+
access_raw = isl_map_set_tuple_name(access_raw, isl_dim_out, tensor_->name.c_str());
192+
}
193+
access = isl::manage(access_raw);
194+
access_raw = nullptr;
195+
183196
ComputeAtTransform transform(domain_, other->domain(), access, transform_, other->transform(), level);
184197
transform();
185198

@@ -457,8 +470,9 @@ std::vector<std::string> Stage::axis_names() const { return GetDimNames(transfor
457470
void Stage::GpuThreads(const std::vector<Iterator> &iters, DeviceAPI device) {
458471
auto dim_names = axis_names();
459472
for (auto &iter : iters) {
460-
CHECK(std::find(dim_names.begin(), dim_names.end(), iter.id) != dim_names.end());
461-
forloop_infos_.emplace(iter.id, StageForloopInfo{ir::ForType::GPUThread, device});
473+
auto it = std::find(dim_names.begin(), dim_names.end(), iter.id);
474+
CHECK(it != dim_names.end());
475+
AddForloopInfo(it - dim_names.begin(), StageForloopInfo{ir::ForType::GPUThread, device});
462476
}
463477
}
464478

@@ -471,7 +485,6 @@ void Stage::GpuBlocks(const std::vector<int> &levels, DeviceAPI device) {
471485
levels.begin(), levels.end(), std::back_inserter(iters), [&](int i) { return Iterator(dim_names[i]); });
472486
GpuBlocks(iters, device);
473487
}
474-
475488
void Stage::GpuBlocks(const Iterator &block_x, DeviceAPI device) {
476489
GpuBlocks(std::vector<Iterator>({block_x}), device);
477490
}
@@ -484,8 +497,21 @@ void Stage::GpuBlocks(const Iterator &block_x, const Iterator &block_y, const It
484497
void Stage::GpuBlocks(const std::vector<Iterator> &iters, DeviceAPI device) {
485498
auto dim_names = axis_names();
486499
for (auto &iter : iters) {
487-
CHECK(std::find(dim_names.begin(), dim_names.end(), iter.id) != dim_names.end());
488-
forloop_infos_.emplace(iter.id, StageForloopInfo{ir::ForType::GPUBlock, device});
500+
auto it = std::find(dim_names.begin(), dim_names.end(), iter.id);
501+
CHECK(it != dim_names.end());
502+
AddForloopInfo(it - dim_names.begin(), StageForloopInfo{ir::ForType::GPUBlock, device});
503+
}
504+
}
505+
void Stage::Bind(int level, const std::string &axis) {
506+
auto dim_names = GetDimNames(transformed_domain().get());
507+
CHECK_LT(level, dim_names.size());
508+
509+
if (axis == "threadIdx.x" || axis == "threadIdx.y" || axis == "threadIdx.z") {
510+
AddForloopInfo(level, StageForloopInfo{ir::ForType::GPUThread, DeviceAPI::GPU});
511+
} else if (axis == "blockIdx.x" || axis == "blockIdx.y" || axis == "blockIdx.z") {
512+
AddForloopInfo(level, StageForloopInfo{ir::ForType::GPUBlock, DeviceAPI::GPU});
513+
} else {
514+
NOT_IMPLEMENTED
489515
}
490516
}
491517

@@ -573,7 +599,7 @@ void Stage::ShareBufferWith(ir::Tensor other) {
573599

574600
void Stage::CtrlDepend(const ir::Tensor &t) { add_extra_depend_stage(t->name); }
575601

576-
std::vector<isl::map> GatherAccesses(Stage *stage, const std::string &tensor_name) {
602+
isl_map *__isl_give GatherAccesses(Stage *stage, const std::string &tensor_name) {
577603
CHECK(stage->tensor_);
578604
auto loads = ir::CollectIRNodes(stage->tensor_->body(), [&](const Expr *x) {
579605
return x->As<ir::Load>() && x->As<ir::Load>()->tensor.as_tensor()->name == tensor_name;
@@ -588,16 +614,26 @@ std::vector<isl::map> GatherAccesses(Stage *stage, const std::string &tensor_nam
588614
std::transform(
589615
loads.begin(), loads.end(), std::back_inserter(out_loads), [](const Expr &x) { return utils::GetStreamCnt(x); });
590616

591-
std::vector<isl::map> res;
592-
617+
isl_map *res = nullptr;
593618
for (auto &load : out_loads) {
594619
std::string repr = utils::StringFormat(
595620
"{ %s[%s] -> %s }", in_tuple_name.c_str(), utils::Join(in_dim_names, ",").c_str(), load.c_str());
596-
res.push_back(isl::map(stage->domain().ctx(), repr));
621+
isl_map *access = isl_map_read_from_str(stage->domain().ctx().get(), repr.c_str());
622+
if (res) {
623+
res = isl_map_union(res, access);
624+
} else {
625+
res = access;
626+
}
597627
}
598628

599629
return res;
600630
}
601631

632+
void Stage::AddForloopInfo(int level, const StageForloopInfo &info) {
633+
int num_levels = isl_map_dim(transform_.get(), isl_dim_out);
634+
CHECK_LT(level, num_levels);
635+
forloop_infos_[level] = info;
636+
}
637+
602638
} // namespace poly
603639
} // namespace cinn

0 commit comments

Comments
 (0)