Skip to content

Commit 68b0cf9

Browse files
authored
【CINN】refactor codegen for cinn (#55955)
* refactor codegen for cinn
* add to_string to some type which can't be += with string
* fix multi-thread bug caused by static var
* delete dead code and comment
1 parent db96ae5 commit 68b0cf9

File tree

10 files changed: 794 additions & 642 deletions

paddle/cinn/backends/codegen_c.cc

Lines changed: 306 additions & 268 deletions
Large diffs are not rendered by default.

paddle/cinn/backends/codegen_c.h

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -56,8 +56,7 @@ class CodeGenC : public ir::IrPrinter {
5656
void SetInlineBuiltinCodes(bool x = true) { inline_builtin_codes_ = x; }
5757

5858
protected:
59-
std::string Compile(const ir::LoweredFunc& function);
60-
std::string Compile(const ir::Buffer& buffer);
59+
void Compile(const ir::LoweredFunc& function);
6160

6261
void GenerateHeaderFile(const ir::Module& module);
6362

@@ -71,9 +70,11 @@ class CodeGenC : public ir::IrPrinter {
7170
// @}
7271

7372
void PrintFunctionDeclaration(const ir::_LoweredFunc_* op) {
74-
os() << "void " << op->name << "(";
75-
os() << "void* _args, int32_t num_args";
76-
os() << ")";
73+
str_ += "void ";
74+
str_ += op->name;
75+
str_ += "(";
76+
str_ += "void* _args, int32_t num_args";
77+
str_ += ")";
7778
}
7879

7980
void PrintShape(const std::vector<Expr>& shape,

paddle/cinn/backends/codegen_c_x86.cc

Lines changed: 39 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -37,13 +37,13 @@ void CodeGenCX86::Visit(const ir::Load *op) {
3737

3838
int bits = op->type().bits() * op->type().lanes();
3939
if (SupportsAVX512() && bits == 512) {
40-
os() << "cinn_avx512_load(";
40+
str_ += "cinn_avx512_load(";
4141
PrintAbsAddr(op);
42-
os() << ")";
42+
str_ += ")";
4343
} else if (SupportsAVX256() && bits == 256) {
44-
os() << "cinn_avx256_load(";
44+
str_ += "cinn_avx256_load(";
4545
PrintAbsAddr(op);
46-
os() << ")";
46+
str_ += ")";
4747
} else {
4848
CodeGenC::Visit(op);
4949
}
@@ -57,13 +57,13 @@ void CodeGenCX86::Visit(const ir::Broadcast *op) {
5757
int bits = op->type().bits() * op->type().lanes();
5858

5959
if (SupportsAVX512() && bits == 512) {
60-
os() << "cinn_avx512_set1(";
60+
str_ += "cinn_avx512_set1(";
6161
PrintCastExpr(op->value.type().ElementOf(), op->value);
62-
os() << ")";
62+
str_ += ")";
6363
} else if (SupportsAVX256() && bits == 256) {
64-
os() << "cinn_avx256_set1(";
64+
str_ += "cinn_avx256_set1(";
6565
PrintCastExpr(op->value.type().ElementOf(), op->value);
66-
os() << ")";
66+
str_ += ")";
6767
} else {
6868
CodeGenC::Visit(op);
6969
}
@@ -77,17 +77,17 @@ void CodeGenCX86::Visit(const ir::Store *op) {
7777

7878
int bits = op->type().bits() * op->type().lanes();
7979
if (SupportsAVX512() && bits == 512) {
80-
os() << "cinn_avx512_store(";
80+
str_ += "cinn_avx512_store(";
8181
PrintAbsAddr(op);
82-
os() << ", ";
83-
Print(op->value);
84-
os() << ")";
82+
str_ += ", ";
83+
IrPrinter::Visit(op->value);
84+
str_ += ")";
8585
} else if (SupportsAVX256() && bits == 256) {
86-
os() << "cinn_avx256_store(";
86+
str_ += "cinn_avx256_store(";
8787
PrintAbsAddr(op);
88-
os() << ", ";
89-
Print(op->value);
90-
os() << ")";
88+
str_ += ", ";
89+
IrPrinter::Visit(op->value);
90+
str_ += ")";
9191
} else {
9292
CodeGenC::Visit(op);
9393
}
@@ -101,18 +101,18 @@ void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
101101
Expr value = op->type().lanes() == 1 ? *op : broadcast_n->value;
102102

103103
if (SupportsAVX512()) {
104-
os() << "cinn_avx512_set1(";
105-
Print(value);
106-
os() << ")";
104+
str_ += "cinn_avx512_set1(";
105+
IrPrinter::Visit(value);
106+
str_ += ")";
107107
} else if (SupportsAVX256()) {
108-
os() << "cinn_avx256_set1(";
109-
Print(value);
110-
os() << ")";
108+
str_ += "cinn_avx256_set1(";
109+
IrPrinter::Visit(value);
110+
str_ += ")";
111111
} else {
112112
CINN_NOT_IMPLEMENTED
113113
}
114114
} else {
115-
Print(*op);
115+
IrPrinter::Visit(*op);
116116
}
117117
}
118118

@@ -123,35 +123,41 @@ void CodeGenCX86::Visit(const ir::intrinsics::BuiltinIntrin *op) {
123123
}
124124
int bits = op->type().bits() * op->type().lanes();
125125
if (SupportsAVX512() && bits == 512) {
126-
os() << "cinn_avx512_" << op->name << "(";
126+
str_ += "cinn_avx512_";
127+
str_ += op->name;
128+
str_ += "(";
127129
if (!op->args.empty()) {
128130
for (int i = 0; i < op->args.size() - 1; i++) {
129131
PrintVecInputArgument(&op->args[i]);
130-
os() << ", ";
132+
str_ += ", ";
131133
}
132-
Print(op->args.back());
134+
IrPrinter::Visit(op->args.back());
133135
}
134-
os() << ")";
136+
str_ += ")";
135137
} else if (SupportsAVX256() && bits == 256) {
136-
os() << "cinn_avx256_" << op->name << "(";
138+
str_ += "cinn_avx256_";
139+
str_ += op->name;
140+
str_ += "(";
137141
if (!op->args.empty()) {
138142
for (int i = 0; i < op->args.size() - 1; i++) {
139143
PrintVecInputArgument(&op->args[i]);
140-
os() << ", ";
144+
str_ += ", ";
141145
}
142146
PrintVecInputArgument(&op->args.back());
143147
}
144-
os() << ")";
148+
str_ += ")";
145149
} else if (bits == 128) {
146-
os() << "cinn_avx128_" << op->name << "(";
150+
str_ += "cinn_avx128_";
151+
str_ += op->name;
152+
str_ += "(";
147153
if (!op->args.empty()) {
148154
for (int i = 0; i < op->args.size() - 1; i++) {
149155
PrintVecInputArgument(&op->args[i]);
150-
os() << ", ";
156+
str_ += ", ";
151157
}
152158
PrintVecInputArgument(&op->args.back());
153159
}
154-
os() << ")";
160+
str_ += ")";
155161
} else {
156162
CodeGenC::Visit(op);
157163
}

paddle/cinn/backends/codegen_c_x86.h

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -91,16 +91,17 @@ class CodeGenCX86 : public CodeGenC {
9191

9292
template <typename Op>
9393
void PrintAbsAddr(const Op *op) {
94-
os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";
94+
str_ += op->tensor.template As<ir::_Tensor_>()->name;
95+
str_ += " + ";
9596

9697
auto index = op->index();
9798
auto *ramp_n = index.template As<ir::Ramp>();
9899
if (ramp_n) {
99100
CHECK(!ramp_n->base.template As<ir::Ramp>())
100101
<< "base of a Ramp node should not be Ramp type";
101-
Print(ramp_n->base);
102+
IrPrinter::Visit(ramp_n->base);
102103
} else {
103-
Print(op->index());
104+
IrPrinter::Visit(op->index());
104105
}
105106
}
106107

@@ -125,17 +126,21 @@ void CodeGenCX86::VisitBinaryOp(const Op *op,
125126
// TODO(Superjomn) Consider support BLAS.
126127
int bits = a.type().bits() * a.type().lanes();
127128
if (SupportsAVX512() && bits == 512) {
128-
os() << "cinn_avx512_" << op_repr << "(";
129+
str_ += "cinn_avx512_";
130+
str_ += op_repr;
131+
str_ += "(";
129132
PrintVecInputArgument(&a);
130-
os() << ", ";
133+
str_ += ", ";
131134
PrintVecInputArgument(&b);
132-
os() << ")";
135+
str_ += ")";
133136
} else if (SupportsAVX256() && bits == 256) {
134-
os() << "cinn_avx256_" << op_repr << "(";
137+
str_ += "cinn_avx256_";
138+
str_ += op_repr;
139+
str_ += "(";
135140
PrintVecInputArgument(&a);
136-
os() << ", ";
141+
str_ += ", ";
137142
PrintVecInputArgument(&b);
138-
os() << ")";
143+
str_ += ")";
139144
} else {
140145
CodeGenC::Visit(op);
141146
}

0 commit comments

Comments
 (0)