Skip to content

Commit b2325e7

Browse files
authored
【CINN】Optimize use of simplify (PaddlePaddle#71321)
* Optimize use of simplify
* Remove simplify of unit loop or block
1 parent 7cf36a5 commit b2325e7

15 files changed

+198
-246
lines changed

paddle/cinn/backends/codegen_c.cc

+2-2
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ void CodeGenC::Visit(const ir::Mul *op) { IrPrinter::Visit(op); }
170170
void CodeGenC::Visit(const ir::Div *op) { IrPrinter::Visit(op); }
171171
void CodeGenC::Visit(const ir::Mod *op) {
172172
auto copied = op->b();
173-
optim::Simplify(&copied);
173+
copied = optim::ArithSimplify(copied);
174174
if (copied.is_constant()) {
175175
int temp = static_cast<int>(copied.get_constant());
176176
if ((temp & (temp - 1)) == 0) {
@@ -891,7 +891,7 @@ void CodeGenC::Visit(const ir::_LoweredFunc_ *op) {
891891

892892
Expr func_body = ir::Block::Make(new_body);
893893

894-
optim::SimplifyBlocks(&func_body);
894+
optim::SimplifyUnitBlock(&func_body);
895895

896896
IrPrinter::Visit(func_body);
897897
}

paddle/cinn/backends/codegen_gpu_dev.cc

+3-3
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ void CodeGenGpuDev::Visit(const ir::_LoweredFunc_ *op) {
183183
ir::stmt::BlockRef func_body_block = ir::stmt::BlockRef(new_body_stmts);
184184

185185
// Use ir_simplify when pass updated.
186-
// optim::SimplifyBlocks(&func_body);
186+
// optim::SimplifyUnitBlock(&func_body);
187187
// // Make sure that the function's body is wrapped by a block
188188
// if (!func_body.As<ir::Block>()) {
189189
// func_body = ir::Block::Make({func_body});
@@ -320,7 +320,7 @@ void CodeGenGpuDev::PrintTempBufferCreation(const ir::Buffer &buffer) {
320320
for (int i = 0; i < buffer->shape.size(); i++) {
321321
buffer_size = buffer_size * buffer->shape[i];
322322
}
323-
optim::Simplify(&buffer_size);
323+
buffer_size = optim::ArithSimplify(buffer_size);
324324
bool has_symbolic_constant = false;
325325
ir::ir_utils::CollectIRNodes(buffer_size, [&](const Expr *x) {
326326
if (x->as_var()) {
@@ -352,7 +352,7 @@ void CodeGenGpuDev::PrintTempBufferCreation(const ir::Buffer &buffer) {
352352
int type_bytes = buffer->dtype.bytes();
353353
dyn_shared_mem_offset_ =
354354
dyn_shared_mem_offset_ + buffer_size * Expr(type_bytes);
355-
optim::Simplify(&dyn_shared_mem_offset_);
355+
dyn_shared_mem_offset_ = optim::ArithSimplify(dyn_shared_mem_offset_);
356356
VLOG(6) << "dyn_shared_mem_offset_ = " << dyn_shared_mem_offset_;
357357
} else if (buffer->memory_type == ir::MemoryType::GPULocal) {
358358
// print func of static allocation

paddle/cinn/backends/sycl/codegen_sycl_dev.cc

+3-3
Original file line number | Diff line number | Diff line change
@@ -194,7 +194,7 @@ void CodeGenSyclDevice::PrintFunctionBody(const ir::_LoweredFunc_ *op) {
194194
APPEND_TO_NEW_BODY_STMTS(dealloc_temp_buffer_stmts);
195195
ir::stmt::BlockRef func_body_block = ir::stmt::BlockRef(new_body_stmts);
196196
// Use ir_simplify when pass updated.
197-
// optim::SimplifyBlocks(&func_body);
197+
// optim::SimplifyUnitBlock(&func_body);
198198
// // Make sure that the function's body is wrapped by a block
199199
// if (!func_body.As<ir::Block>()) {
200200
// func_body = ir::Block::Make({func_body});
@@ -253,7 +253,7 @@ void CodeGenSyclDevice::PrintTempBufferCreation(const ir::Buffer &buffer) {
253253
for (int i = 0; i < buffer->shape.size(); i++) {
254254
buffer_size = buffer_size * buffer->shape[i];
255255
}
256-
optim::Simplify(&buffer_size);
256+
buffer_size = optim::ArithSimplify(buffer_size);
257257
IrPrinter::Visit(buffer_size);
258258
str_ += " ]";
259259
};
@@ -268,7 +268,7 @@ void CodeGenSyclDevice::PrintTempBufferCreation(const ir::Buffer &buffer) {
268268
for (int i = 0; i < buffer->shape.size(); i++) {
269269
buffer_size = buffer_size * buffer->shape[i];
270270
}
271-
optim::Simplify(&buffer_size);
271+
buffer_size = optim::ArithSimplify(buffer_size);
272272
IrPrinter::Visit(buffer_size);
273273
str_ += " ]>(item.get_group())";
274274
break;

paddle/cinn/hlir/pe/ir_schedule_pe.cc

+4-4
Original file line number | Diff line number | Diff line change
@@ -1300,9 +1300,9 @@ void IRCudaScheduleConv(ir::IRSchedule &ir_sch, // NOLINT
13001300

13011301
int n = output->shape[0].as_int32();
13021302
int c = output->shape[1].as_int32();
1303-
optim::Simplify(&(output->shape[2]));
1303+
output->shape[2] = optim::ArithSimplify(output->shape[2]);
13041304
int h = output->shape[2].as_int32();
1305-
optim::Simplify(&(output->shape[3]));
1305+
output->shape[3] = optim::ArithSimplify(output->shape[3]);
13061306
int w = output->shape[3].as_int32();
13071307
int rc = input_pad->shape[1].as_int32();
13081308

@@ -1480,8 +1480,8 @@ void IRCudaScheduleConv2(ir::IRSchedule &ir_sch, // NOLINT
14801480

14811481
// stages[input_pad]->ComputeInline();
14821482

1483-
optim::Simplify(&(output->shape[2]));
1484-
optim::Simplify(&(output->shape[3]));
1483+
output->shape[2] = optim::ArithSimplify(output->shape[2]);
1484+
output->shape[3] = optim::ArithSimplify(output->shape[3]);
14851485

14861486
VLOG(3) << "Begin IRCudaScheduleConv2 with expr : "
14871487
<< ir_sch.GetModule().GetExprs().at(0);

paddle/cinn/ir/group_schedule/config/group_tile_util.cc

+2-2
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,7 @@ bool CheckTensorIsBroadcastAndContinuous(
206206
bool is_broadcast = false;
207207
for (int i = 0; i < indices.size(); ++i) {
208208
ir::Expr index = indices[i];
209-
cinn::optim::Simplify(&index);
209+
index = optim::ArithSimplify(index);
210210
if (index.is_constant() && index.get_constant() == 0) {
211211
is_broadcast = true;
212212
continue;
@@ -244,7 +244,7 @@ bool CheckTensorIsContinuous(
244244
const std::unordered_map<ir::Var, ir::Expr>& iter_var2value) {
245245
for (int i = 0; i < indices.size(); ++i) {
246246
ir::Expr index = indices[i];
247-
cinn::optim::Simplify(&index);
247+
index = optim::ArithSimplify(index);
248248
if (index.is_constant()) return false;
249249
if (!index.is_var()) return false;
250250
ir::Var iter_var = index.as_var_ref();

paddle/cinn/ir/ir_base.h

+6-2
Original file line number | Diff line number | Diff line change
@@ -177,8 +177,12 @@ enum class StmtNodeTy { kUnk = -1, NODETY_FORALL_STMT(__m) };
177177
//! String representations for IrNodeTy.
178178
// @{
179179
#define __m(x__) #x__,
180-
const std::vector<std::string> kIrNodeTyReprs(
181-
{NODETY_FORALL(__m) "IterSplit", "IterSum", "IterMark", "None"});
180+
const std::vector<std::string> kIrNodeTyReprs({"Module",
181+
"LoweredFunc",
182+
"IterSplit",
183+
"IterSum",
184+
"IterMark",
185+
NODETY_FORALL(__m)});
182186
#undef __m
183187
// @}
184188

paddle/cinn/lang/compute.cc

+2-4
Original file line number | Diff line number | Diff line change
@@ -178,14 +178,12 @@ ir::Tensor Compute(const std::vector<Expr> &domain,
178178

179179
// construct the shape.
180180
for (auto dim : domain) {
181-
auto copied = dim;
182-
optim::Simplify(&copied);
181+
auto copied = optim::ArithSimplify(dim);
183182
domain_without_reduce_axis.push_back(copied);
184183
}
185184

186185
for (auto dim : shape) {
187-
auto copied = dim;
188-
optim::Simplify(&copied);
186+
auto copied = optim::ArithSimplify(dim);
189187
shape_simplified.push_back(copied);
190188
}
191189

paddle/cinn/lang/lower_impl.cc

+1-1
Original file line number | Diff line number | Diff line change
@@ -384,7 +384,7 @@ std::vector<ir::LoweredFunc> LowerImpl::operator()() {
384384

385385
if (support_ir_schedule_) {
386386
optim::TransformPolyForToFor(&func->body);
387-
optim::SimplifyBlocks(&func->body);
387+
optim::SimplifyUnitBlock(&func->body);
388388
func->body = ir::Block::Make({func->body});
389389
result.push_back(func);
390390
num_func++;

0 commit comments

Comments (0)