Skip to content

Commit fe0c647

Browse files
authored
fix CacheRead conflict with ComputeAt (PaddlePaddle#126)
1 parent 1f98e30 commit fe0c647

18 files changed

+174
-45
lines changed

cinn/backends/codegen_cuda_dev_test.cc

+77-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ namespace backends {
2424

2525
std::tuple<CUdeviceptr, CUdeviceptr, CUdeviceptr, std::vector<float>, std::vector<float>, std::vector<float>>
2626
CreateNVMemory(int M, int N) {
27-
CUDA_CALL(cudaThreadSynchronize());
27+
CUDA_CALL(cudaDeviceSynchronize());
2828

2929
CUdeviceptr Ad, Bd, Cd;
3030
cuMemAlloc(&Ad, M * N * sizeof(float));
@@ -419,7 +419,7 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {
419419
B_buf->host_memory = reinterpret_cast<uint8_t*>(Bd);
420420
C_buf->host_memory = reinterpret_cast<uint8_t*>(Cd);
421421

422-
CUDA_CALL(cudaThreadSynchronize());
422+
CUDA_CALL(cudaDeviceSynchronize());
423423

424424
// call the kernel
425425
auto comp = reinterpret_cast<void (*)(cinn_pod_value_t*, int)>(fn_ptr);
@@ -428,7 +428,7 @@ TEST(CodeGenCUDA, jit_host_call_cuda_kernel) {
428428

429429
comp(args.data(), args.size());
430430

431-
CUDA_CALL(cudaThreadSynchronize());
431+
CUDA_CALL(cudaDeviceSynchronize());
432432

433433
CUDA_CALL(cudaMemcpy(host_data3.data(),
434434
reinterpret_cast<void*>(Cd),
@@ -716,7 +716,7 @@ TEST(elementwise_add, share_local_cache) {
716716
B_buf->host_memory = reinterpret_cast<uint8_t*>(Bd);
717717
C_buf->host_memory = reinterpret_cast<uint8_t*>(Cd);
718718

719-
CUDA_CALL(cudaThreadSynchronize());
719+
CUDA_CALL(cudaDeviceSynchronize());
720720

721721
// call the kernel
722722
auto comp = reinterpret_cast<void (*)(cinn_pod_value_t*, int)>(fn_ptr);
@@ -725,7 +725,7 @@ TEST(elementwise_add, share_local_cache) {
725725

726726
comp(args.data(), args.size());
727727

728-
CUDA_CALL(cudaThreadSynchronize());
728+
CUDA_CALL(cudaDeviceSynchronize());
729729
}
730730

731731
CUDA_CALL(cudaFree(reinterpret_cast<void*>(Ad)))
@@ -883,6 +883,8 @@ TEST(Conv, optimize) {
883883
}
884884

885885
TEST(ElementwiseAdd, cache_read) {
886+
Context::Global().ResetNameId();
887+
886888
Expr M(100);
887889
Expr N(200);
888890

@@ -933,14 +935,82 @@ void fn_kernel(const float* __restrict__ A, const float* __restrict__ B, float*
933935
};
934936
};
935937
for (int32_t i = 0; i < 10; i += 1) {
936-
C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[((10 * blockIdx.x) + ((10 * threadIdx.x) + i))] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]);
938+
C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_3[i] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))]);
939+
};
940+
};
941+
}
942+
943+
}
944+
)ROC";
945+
ASSERT_EQ(utils::Trim(source_target), source_code);
946+
947+
backends::NVRTC_Compiler compiler;
948+
949+
auto ptx = compiler(source_code);
950+
CHECK(!ptx.empty()) << "Compile error!";
951+
}
952+
953+
TEST(ElementwiseAdd, cache_read1) {
954+
Expr M(100);
955+
Expr N(200);
956+
957+
Placeholder<float> A("A", {M, N});
958+
Placeholder<float> B("B", {M, N});
959+
960+
auto C = Compute(
961+
{M - 2, N}, [&](Expr i, Expr j) { return A(i, j) + A(i + 1, j) + A(i + 2, j) + B(i, j); }, "C");
962+
C->stage()->Split(1, 10);
963+
964+
auto AL = A->stage()->CacheRead("local", {C});
965+
AL->stage()->Split(1, 10);
966+
967+
AL->stage()->ComputeAt(C->stage(), 1, poly::Stage::ComputeAtKind::kComputeAtUnk, A->name);
968+
C->stage()->Bind(0, "threadIdx.x");
969+
C->stage()->Bind(1, "blockIdx.x");
970+
971+
Target target;
972+
CodeGenCUDA_Dev codegen(target);
973+
974+
auto fn = Lower("fn", {A, B, C}, {}, {AL});
975+
976+
Module::Builder builder("module", target);
977+
builder.AddFunction(fn);
978+
979+
auto source_code = codegen.Compile(builder.Build());
980+
std::cout << "source:\n" << source_code << std::endl;
981+
982+
std::string source_target = R"ROC(
983+
extern "C" {
984+
985+
#ifdef __CUDACC_RTC__
986+
typedef int int32_t;
987+
typedef char int8_t;
988+
#endif
989+
990+
991+
992+
__global__
993+
void fn_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
994+
{
995+
float _A_read_cache_6 [ 3 * 10 ];
996+
float* A_read_cache_6 = _A_read_cache_6;
997+
{
998+
if (((((threadIdx.x >= 0) && (threadIdx.x <= 97)) && (blockIdx.x >= 0)) && (blockIdx.x <= 19))) {
999+
for (int32_t i = threadIdx.x; i < (3 + threadIdx.x); i += 1) {
1000+
for (int32_t j_inner = 0; j_inner < 10; j_inner += 1) {
1001+
A_read_cache_6[((10 * i) + j_inner)] = A[((10 * blockIdx.x) + ((200 * i) + j_inner))];
1002+
};
1003+
};
1004+
};
1005+
for (int32_t i = 0; i < 10; i += 1) {
1006+
C[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))] = (A_read_cache_6[i] + (A_read_cache_6[(10 + i)] + (A_read_cache_6[(20 + i)] + B[((10 * blockIdx.x) + ((200 * threadIdx.x) + i))])));
9371007
};
9381008
};
9391009
}
9401010
9411011
}
9421012
)ROC";
943-
// ASSERT_EQ(utils::Trim(source_target), source);
1013+
ASSERT_EQ(utils::Trim(source_target), source_code);
9441014

9451015
backends::NVRTC_Compiler compiler;
9461016

cinn/common/cas.cc

+1
Original file line numberDiff line numberDiff line change
@@ -1478,6 +1478,7 @@ Expr SolveInequality(Expr inequality, Var val) {
14781478
} else {
14791479
return AutoSimplify(inequality);
14801480
}
1481+
return Expr();
14811482
}
14821483

14831484
} // namespace common

cinn/hlir/pe/add.h

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22
#include <string>
33
#include <vector>
4+
45
#include "cinn/common/common.h"
56
#include "cinn/ir/ir.h"
67
#include "cinn/ir/node.h"

cinn/lang/lower_impl.cc

+13-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "cinn/common/ir_util.h"
88
#include "cinn/ir/ir_printer.h"
99
#include "cinn/lang/tensor.h"
10+
#include "cinn/optim/cache_read_write_replace.h"
1011
#include "cinn/optim/ir_replace.h"
1112
#include "cinn/poly/compute_at_transform.h"
1213

@@ -30,7 +31,9 @@ void CheckNoIslCallRemains(Expr* expr) {
3031
}
3132
}
3233

33-
Expr LowerGroup(const poly::ScheduleGroup& group, const std::map<std::string, Expr>& tuple_to_expr) {
34+
Expr LowerGroup(const poly::ScheduleGroup& group,
35+
const std::map<std::string, Expr>& tuple_to_expr,
36+
std::map<std::string, ir::Tensor>* global_tensor_map) {
3437
std::vector<poly::Stage*> stages;
3538
for (auto& node : group.nodes) {
3639
if (node->stage->has_expression()) {
@@ -73,6 +76,8 @@ Expr LowerGroup(const poly::ScheduleGroup& group, const std::map<std::string, Ex
7376
}
7477
CheckNoIslCallRemains(&e);
7578

79+
optim::CacheReadWriteReplace(&e, global_tensor_map);
80+
7681
// deal with the compute_at relations
7782
ProcessComputeAtInfo(&e);
7883

@@ -375,6 +380,8 @@ Expr LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule) {
375380
auto tensor_map = GenAllTensorMap();
376381
std::map<std::string, Expr> tuple_to_expr;
377382
CHECK(!schedule->groups.empty()) << "no group is generated";
383+
384+
std::map<std::string, ir::Tensor> global_tensor_map;
378385
for (auto& group : schedule->groups) {
379386
CHECK_GT(group.nodes.size(), 0) << "group is empty";
380387
for (auto& node : group.nodes) {
@@ -384,7 +391,7 @@ Expr LowerImpl::GenerateFunctionBody(const poly::Schedule* schedule) {
384391
tuple_to_expr[tensor->name] = tensor->tensor_store_expanded_body();
385392
}
386393

387-
Expr group_expr = LowerGroup(group, tuple_to_expr);
394+
Expr group_expr = LowerGroup(group, tuple_to_expr, &global_tensor_map);
388395
if (group_expr.defined()) {
389396
VLOG(3) << "group expr:\n" << group_expr;
390397
exprs.push_back(group_expr);
@@ -530,8 +537,10 @@ struct CorrectComputeAtRelatedIndiceMutator : public ir::IRMutator<> {
530537
auto* node = expr->As<ir::Store>();
531538
CHECK(node);
532539

540+
VLOG(3) << "SetProducerAxisToZeroInStore: " << *expr;
533541
for (auto& indice : node->indices) {
534542
for (auto& consumer_axis : consumer_axis) {
543+
VLOG(3) << indice << " set producer axis [" << consumer_axis << "] to 0";
535544
optim::IrReplace(&indice, consumer_axis, common::make_const(0));
536545
}
537546
}
@@ -626,11 +635,13 @@ struct CorrectComputeAtRelatedIndiceMutator : public ir::IRMutator<> {
626635
void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
627636

628637
void Visit(const ir::Load* op, Expr* expr) override {
638+
VLOG(3) << "Consumer modify Load " << *expr << "'s axis for producer [" << producer_tensor_name << "]";
629639
auto* node = expr->As<ir::Load>();
630640
if (op->tensor.as_tensor()->name == producer_tensor_name) {
631641
CHECK_LE(compute_at_info.preceding_offset_for_producer_load.size(), node->indices.size());
632642
for (auto axis : consumer_axis) {
633643
for (auto& indice : node->indices) {
644+
VLOG(3) << "Consumer Load " << indice << " set axis [" << axis << "] to 0";
634645
optim::IrReplace(&indice, axis, common::make_const(0));
635646
}
636647
}

cinn/lang/lower_impl.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ void CheckNoIslCallRemains(const Expr* expr);
4848
* @param group A single schedule group containing several Stages and the scheduling order.
4949
* @param tuple_to_expr A map from isl set tuple name to CINN expressions.
5050
*/
51-
Expr LowerGroup(const poly::ScheduleGroup& group, const std::map<std::string, Expr>& tuple_to_expr);
51+
Expr LowerGroup(const poly::ScheduleGroup& group,
52+
const std::map<std::string, Expr>& tuple_to_expr,
53+
std::map<std::string, Tensor>* global_tensor_map);
5254

5355
/**
5456
* A Computation graph node.

cinn/optim/cache_read_write_replace.cc

+10-2
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,13 @@ struct CacheReplaceMutator : public ir::IRMutator<> {
7676

7777
} // namespace
7878

79-
void CacheReadWriteReplace(Expr* expr) {
79+
void CacheReadWriteReplace(Expr* expr, std::map<std::string, ir::Tensor>* global_tensor_map) {
8080
auto cached_tensors = ir::CollectIRNodes(*expr, [](const Expr* x) {
8181
auto* t = x->as_tensor();
8282
return t && (t->read_cache_relation || t->write_cache_relation);
8383
});
8484

85+
LOG(INFO) << "expr: " << *expr;
8586
auto tensors = ir::CollectIRNodes(*expr, [](const Expr* x) { return x->as_tensor(); });
8687

8788
std::set<ir::Tensor> uniq_cached_tensors;
@@ -95,9 +96,16 @@ void CacheReadWriteReplace(Expr* expr) {
9596
tensor_map[t->name] = t;
9697
}
9798

99+
// update global_tensor_map
100+
for (auto& item : tensor_map) {
101+
if (!global_tensor_map->count(item.first)) {
102+
(*global_tensor_map)[item.first] = item.second;
103+
}
104+
}
105+
98106
for (auto& t : uniq_cached_tensors) {
99107
if (t->read_cache_relation) {
100-
auto cache = tensor_map.at(t->read_cache_relation->cache_name);
108+
auto cache = global_tensor_map->at(t->read_cache_relation->cache_name);
101109
CacheReplaceMutator(t->name, cache, t->read_cache_relation->readers, true /*read*/)(expr);
102110
}
103111
if (t->write_cache_relation) {

cinn/optim/cache_read_write_replace.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#pragma once
2+
#include <string>
3+
24
#include "cinn/ir/ir.h"
35

46
namespace cinn {
57
namespace optim {
68

7-
void CacheReadWriteReplace(Expr* expr);
9+
void CacheReadWriteReplace(Expr* expr, std::map<std::string, ir::Tensor>* global_tensor_map);
810

911
} // namespace optim
1012
} // namespace cinn

cinn/optim/optimize.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ Expr Optimize(Expr e, bool runtime_debug_info) {
3333
#ifdef CINN_WITH_CUDA
3434
RemoveGpuForloopsAxis(&copied);
3535
#endif
36-
CacheReadWriteReplace(&copied);
36+
// CacheReadWriteReplace(&copied);
3737

3838
RemoveNestedBlock(&copied);
3939

cinn/poly/compute_at_transform.cc

+16-16
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ void ComputeAtTransform::AdjustPdomain() {
99

1010
isl::set cdomain1 = isl::manage(AddParamsTo(cdomain_.copy()));
1111

12-
LOG(INFO) << "ct_domain: " << ct_domain.space();
13-
LOG(INFO) << "cdomain1: " << cdomain1.space();
12+
VLOG(3) << "ct_domain: " << ct_domain.space();
13+
VLOG(3) << "cdomain1: " << cdomain1.space();
1414

1515
ct_domain = ct_domain.intersect(cdomain1);
16-
LOG(INFO) << "ct_domain: " << ct_domain;
16+
VLOG(3) << "ct_domain: " << ct_domain;
1717

1818
// get producer domain from access
1919
isl::map access_with_params = isl::manage(AddParamsTo(access_.copy()));
@@ -22,11 +22,11 @@ void ComputeAtTransform::AdjustPdomain() {
2222

2323
// intersect with the original producer domain
2424
auto pdomain_params = isl::manage(AddParamsTo(pdomain_.copy()));
25-
LOG(INFO) << "pdomain: " << pdomain;
26-
LOG(INFO) << "pdomain_params: " << pdomain_params;
25+
VLOG(4) << "pdomain: " << pdomain;
26+
VLOG(4) << "pdomain_params: " << pdomain_params;
2727
adjusted_pdomain_ = isl::manage(isl_set_intersect(pdomain.release(), pdomain_params.release()));
2828
adjusted_pdomain_ = isl::manage(isl_simplify(adjusted_pdomain_.release()));
29-
LOG(INFO) << "adjusted pdomain: " << adjusted_pdomain_;
29+
VLOG(4) << "adjusted pdomain: " << adjusted_pdomain_;
3030
}
3131

3232
void ComputeAtTransform::AdjustPtransform() {
@@ -53,7 +53,7 @@ void ComputeAtTransform::AdjustPtransform() {
5353
ct_range1 = isl::manage(isl_set_set_tuple_name(ct_range1.release(), ptuple()));
5454

5555
adjusted_ptransform_ = adjusted_ptransform_.intersect_range(ct_range1);
56-
LOG(INFO) << "adjusted_ptransform: " << adjusted_ptransform_;
56+
VLOG(4) << "adjusted_ptransform: " << adjusted_ptransform_;
5757
}
5858

5959
{ // add params
@@ -86,8 +86,8 @@ isl::map ComputeAtTransform::ctransform_with_params() {
8686
}
8787

8888
void ComputeAtTransform::DisplayC(isl_map* pschedule, isl_map* cschedule) {
89-
LOG(INFO) << "adjusted cdomain: " << adjusted_cdomain_;
90-
LOG(INFO) << "adjusted ctransform: " << adjusted_ctransform_;
89+
VLOG(3) << "adjusted cdomain: " << adjusted_cdomain_;
90+
VLOG(3) << "adjusted ctransform: " << adjusted_ctransform_;
9191

9292
auto adjusted_ctransform = adjusted_ctransform_;
9393
auto adjusted_ptransform = adjusted_ptransform_;
@@ -101,11 +101,11 @@ void ComputeAtTransform::DisplayC(isl_map* pschedule, isl_map* cschedule) {
101101

102102
auto whole_domain = isl::manage(isl_union_set_from_set(adjusted_pdomain_.copy()));
103103
whole_domain = isl::manage(isl_union_set_add_set(whole_domain.release(), adjusted_cdomain_.copy()));
104-
LOG(INFO) << "whole domain: " << whole_domain;
104+
VLOG(3) << "whole domain: " << whole_domain;
105105

106106
auto whole_schedule = isl::manage(isl_union_map_from_map(adjusted_ptransform.copy()));
107107
whole_schedule = isl::manage(isl_union_map_add_map(whole_schedule.release(), adjusted_ctransform.copy()));
108-
LOG(INFO) << "whole_schedule: " << whole_schedule;
108+
VLOG(3) << "whole_schedule: " << whole_schedule;
109109

110110
isl::set context(whole_domain.ctx(), "{:}");
111111

@@ -166,7 +166,7 @@ std::string GenConsumerParamName(const char* tuple, int id) {
166166
}
167167

168168
std::vector<int> ComputeAtTransform::GetProducerAdjustedShape() const {
169-
LOG(INFO) << "domain: " << adjusted_pdomain();
169+
VLOG(3) << "domain: " << adjusted_pdomain();
170170
isl::set param_limit = isl::manage(isl_set_universe(adjusted_pdomain().space().release()));
171171
// set all the params to 0
172172
isl_local_space* local_space = isl_local_space_from_space(param_limit.space().release());
@@ -193,12 +193,12 @@ std::vector<int> ComputeAtTransform::GetAccessesPrecedingIndicesMinAssumingParam
193193
std::vector<int> res;
194194

195195
isl::set cdomain_with_param = isl::manage(AddParamsTo(cdomain_.copy()));
196-
LOG(INFO) << "cdomain_with_param: " << cdomain_with_param;
196+
VLOG(4) << "cdomain_with_param: " << cdomain_with_param;
197197
isl::map access_with_param = isl::manage(AddParamsTo(access_.copy()));
198198

199-
LOG(INFO) << "*** applied: " << cdomain_with_param.apply(access_with_param);
199+
VLOG(4) << "applied: " << cdomain_with_param.apply(access_with_param);
200200
isl::set param_limited_cdomain = ctransform_with_params().domain();
201-
LOG(INFO) << "ctransform.domain: " << param_limited_cdomain;
201+
VLOG(4) << "ctransform.domain: " << param_limited_cdomain;
202202
isl::set access_domain = param_limited_cdomain.apply(access_with_param);
203203

204204
// set all the params to 0
@@ -212,7 +212,7 @@ std::vector<int> ComputeAtTransform::GetAccessesPrecedingIndicesMinAssumingParam
212212

213213
access_domain = access_domain.intersect(adjusted_pdomain());
214214

215-
LOG(INFO) << "access_with_param: " << access_domain;
215+
VLOG(3) << "access_with_param: " << access_domain;
216216

217217
for (int i = 0; i < level_ + 1; i++) {
218218
auto [minv, maxv] = isl_set_get_axis_range(access_domain.get(), i);

0 commit comments

Comments
 (0)