Skip to content

Commit 422c5d4

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_randomness_in_fusion
2 parents 0a31620 + 21446e9 commit 422c5d4

File tree

87 files changed

+2817
-242
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+2817
-242
lines changed

_typos.toml

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ grad = "grad"
1010
arange = "arange"
1111
ot = 'ot'
1212
pash = 'pash'
13+
eles = 'eles'
1314

1415
# These words need to be fixed
1516
ontext = 'ontext'
@@ -25,7 +26,6 @@ expection = 'expection'
2526
pacakage = 'pacakage'
2627
normlized = 'normlized'
2728
Runing = 'Runing'
28-
Ags = 'Ags'
2929
Becasue = 'Becasue'
3030
craeted = 'craeted'
3131
oreder = 'oreder'
@@ -154,7 +154,6 @@ CACH = 'CACH'
154154
endianess = 'endianess'
155155
VAILD = 'VAILD'
156156
ues = 'ues'
157-
algorithem = 'algorithem'
158157
aer = 'aer'
159158
elemenents = 'elemenents'
160159
CANN = 'CANN'
@@ -181,7 +180,6 @@ overrided = 'overrided'
181180
smll = 'smll'
182181
outpout = 'outpout'
183182
staticaly = 'staticaly'
184-
aranged = 'aranged'
185183
offets = 'offets'
186184
olny = 'olny'
187185
Continer = 'Continer'
@@ -192,7 +190,6 @@ readed = 'readed'
192190
Opeartion = 'Opeartion'
193191
shoule = 'shoule'
194192
inputed = 'inputed'
195-
arrray = 'arrray'
196193
positon = 'positon'
197194
invalide = 'invalide'
198195
repeatly = 'repeatly'
@@ -255,16 +252,13 @@ defition = 'defition'
255252
operants = 'operants'
256253
funcitons = 'funcitons'
257254
dateset = 'dateset'
258-
arised = 'arised'
259255
optimzed = 'optimzed'
260256
encouter = 'encouter'
261-
alis = 'alis'
262257
feeded = 'feeded'
263258
poped = 'poped'
264259
parmeter = 'parmeter'
265260
doens = 'doens'
266261
cadidate = 'cadidate'
267-
argumnets = 'argumnets'
268262
inconsistence = 'inconsistence'
269263
Caculate = 'Caculate'
270264
seperator = 'seperator'
@@ -358,7 +352,6 @@ accessable = 'accessable'
358352
Wheter = 'Wheter'
359353
processer = 'processer'
360354
recored = 'recored'
361-
afer = 'afer'
362355
wraper = 'wraper'
363356
Partitial = 'Partitial'
364357
leafs = 'leafs'
@@ -417,7 +410,6 @@ recieved = 'recieved'
417410
Hanlder = 'Hanlder'
418411
EPOCHES = 'EPOCHES'
419412
sequnce = 'sequnce'
420-
argmuments = 'argmuments'
421413
Iteraion = 'Iteraion'
422414
whill = 'whill'
423415
tood = 'tood'
@@ -478,7 +470,6 @@ subsituted = 'subsituted'
478470
automaticly = 'automaticly'
479471
Minium = 'Minium'
480472
sequnece = 'sequnece'
481-
adjancent = 'adjancent'
482473
payed = 'payed'
483474
linke = 'linke'
484475
nagative = 'nagative'
@@ -556,7 +547,6 @@ splited = 'splited'
556547
instrinsics = 'instrinsics'
557548
outputing = 'outputing'
558549
hadler = 'hadler'
559-
aggragate = 'aggragate'
560550
qucik = 'qucik'
561551
alog = 'alog'
562552
exsit = 'exsit'
@@ -664,7 +654,6 @@ usefull = 'usefull'
664654
sqaure = 'sqaure'
665655
adn = 'adn'
666656
intialize = 'intialize'
667-
addtional = 'addtional'
668657
Taget = 'Taget'
669658
parm = 'parm'
670659
thrads = 'thrads'
@@ -686,7 +675,6 @@ upsupported = 'upsupported'
686675
settting = 'settting'
687676
templat = 'templat'
688677
priorites = 'priorites'
689-
admissable = 'admissable'
690678
optmize = 'optmize'
691679
bondary = 'bondary'
692680
Traning = 'Traning'
@@ -732,7 +720,6 @@ mannualy = 'mannualy'
732720
learing = 'learing'
733721
noraml = 'noraml'
734722
padd = 'padd'
735-
alogrithm = 'alogrithm'
736723
Requred = 'Requred'
737724
astroid = 'astroid'
738725
Fidn = 'Fidn'
@@ -778,14 +765,12 @@ expaned = 'expaned'
778765
choos = 'choos'
779766
whos = 'whos'
780767
architecuture = 'architecuture'
781-
argumet = 'argumet'
782768
coule = 'coule'
783769
instanciate = 'instanciate'
784770
distrubuted = 'distrubuted'
785771
Localy = 'Localy'
786772
PARM = 'PARM'
787773
thi = 'thi'
788-
addess = 'addess'
789774
Oll = 'Oll'
790775
Auxillary = 'Auxillary'
791776
Infor = 'Infor'

paddle/cinn/backends/codegen_gpu_dev.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ void CodeGenGpuDev::Visit(const ir::_LoweredFunc_ *op) {
151151
auto axis_range_assumptions = op->PrepareAxisRangeAssumptions();
152152
auto alloca_temp_buffers = op->PrepareAllocTempBufferExprs();
153153
auto temp_buffer_alias = GenerateBufferAliasExprs(op, op->temp_bufs);
154-
auto alis_var_exprs = op->CudaAliasVarExprs();
154+
auto alias_var_exprs = op->CudaAliasVarExprs();
155155
auto dealloc_temp_buffers =
156156
FilterDeallocTempBuffers(op->PrepareDeallocTempBufferExprs());
157157

@@ -160,7 +160,7 @@ void CodeGenGpuDev::Visit(const ir::_LoweredFunc_ *op) {
160160
APPEND_TO_NEW_BODY(axis_range_assumptions)
161161
APPEND_TO_NEW_BODY(alloca_temp_buffers)
162162
APPEND_TO_NEW_BODY(temp_buffer_alias)
163-
APPEND_TO_NEW_BODY(alis_var_exprs)
163+
APPEND_TO_NEW_BODY(alias_var_exprs)
164164

165165
new_body.push_back(op->body);
166166
APPEND_TO_NEW_BODY(dealloc_temp_buffers);

paddle/cinn/common/object.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,42 @@ struct Object {
7575
using object_ptr = Object*;
7676
using shared_object = Shared<Object>;
7777

78+
/*
79+
* \brief Delete the default copy/move constructor and assign operator
80+
* \param TypeName The class typename.
81+
*/
82+
#define CINN_DELETE_COPY_MOVE_AND_ASSIGN_FOR_OBJECT_NODE(TypeName) \
83+
TypeName(const TypeName& other) = delete; \
84+
TypeName(TypeName&& other) = delete; \
85+
TypeName& operator=(const TypeName& other) = delete; \
86+
TypeName& operator=(TypeName&& other) = delete;
87+
88+
/*
89+
* \brief Define the default copy/move constructor and assign operator
90+
* \param TypeName The class typename.
91+
*/
92+
#define CINN_DEFINE_COPY_MOVE_AND_ASSIGN_FOR_OBJECT_REF(TypeName) \
93+
TypeName(const TypeName& other) = default; \
94+
TypeName(TypeName&& other) = default; \
95+
TypeName& operator=(const TypeName& other) = default; \
96+
TypeName& operator=(TypeName&& other) = default;
97+
98+
/*
99+
* \brief Define object reference methods
100+
* \param TypeName The object type name
101+
* \param ParentType The parent type of the objectref
102+
* \param ObjectName The type name of the object
103+
*/
104+
#define CINN_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
105+
TypeName() = default; \
106+
explicit TypeName(ObjectName* n) : ParentType(n) {} \
107+
explicit TypeName(const cinn::common::Shared<Object>& ref) \
108+
: ParentType(ref) {} \
109+
CINN_DEFINE_COPY_MOVE_AND_ASSIGN_FOR_OBJECT_REF(TypeName); \
110+
const ObjectName* operator->() const { \
111+
return static_cast<const ObjectName*>(this->p_); \
112+
} \
113+
ObjectName* operator->() { return static_cast<ObjectName*>(this->p_); }
114+
78115
} // namespace common
79116
} // namespace cinn

paddle/cinn/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ gather_srcs(
1919
intrinsic_ops.cc
2020
layout.cc
2121
schedule_block_graph.cc
22+
stmt.cc
23+
stmt_visitors.cc
2224
dim.cc)
2325

2426
add_subdirectory(ir_analyzer)

paddle/cinn/ir/ir_base.cc

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,21 @@ std::ostream &operator<<(std::ostream &os, IrNodeTy type) {
6666
return os;
6767
}
6868

69+
std::ostream &operator<<(std::ostream &os, StmtNodeTy type) {
70+
switch (type) {
71+
#define __m(t__) \
72+
case StmtNodeTy::t__: \
73+
os << "<stmt node: " << #t__ << ">"; \
74+
break;
75+
76+
NODETY_FORALL_STMT(__m)
77+
#undef __m
78+
default:
79+
PADDLE_THROW(
80+
::common::errors::InvalidArgument("unknown StmtNodeTy found"));
81+
}
82+
}
83+
6984
Expr Zero(const Type &type) {
7085
if (type.is_bfloat16()) return Expr(bfloat16(0.f));
7186
if (type.is_float16()) return Expr(float16(0.f));
@@ -330,15 +345,18 @@ void IrNode::replace(Expr old_op, Expr new_op) {
330345
}
331346

332347
void IrNode::convert_int32_to_int64() {
333-
if (type_ != Int(64))
334-
if (type_ != Int(32))
348+
if (type_ != Int(64) && type_ != UInt(64))
349+
if (type_ != Int(32) && type_ != UInt(32))
335350
PADDLE_ENFORCE_EQ(type_.is_unk(),
336351
true,
337352
::common::errors::InvalidArgument(
338353
"Current only support convert int32_t "
339354
"to int64_t, but get type is: %s",
340355
type_));
341-
type_ = Int(64);
356+
357+
if (type_ == Int(32)) type_ = Int(64);
358+
if (type_ == UInt(32)) type_ = UInt(64);
359+
342360
for (Expr &operand : operands) {
343361
operand->convert_int32_to_int64();
344362
}

paddle/cinn/ir/ir_base.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class Dim;
116116
#define NODETY_CONTROL_OP_FOR_INTRINSIC(macro__) \
117117
macro__(IntrinsicOp) \
118118

119+
// TODO(Hongqing-work): change NODETY_FORALL to NODETY_FORALL_EXPR
119120
#define NODETY_FORALL(__m) \
120121
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
121122
NODETY_OP_FOR_EACH(__m) \
@@ -126,6 +127,16 @@ class Dim;
126127
NODETY_PRIMITIVE_TYPE_FOR_EACH(__m) \
127128
NODETY_OP_FOR_EACH(__m) \
128129
NODETY_CONTROL_OP_FOR_EACH(__m)
130+
131+
#define NODETY_FORALL_STMT(macro__) \
132+
macro__(Let) \
133+
macro__(Store) \
134+
macro__(Alloc) \
135+
macro__(Free) \
136+
macro__(IfThenElse) \
137+
macro__(For) \
138+
macro__(Schedule) \
139+
macro__(Evaluate)
129140
// clang-format on
130141

131142
//! Define IrNodeTy
@@ -143,6 +154,13 @@ enum class IrNodeTy {
143154
#undef __m
144155
// @}
145156

157+
//! Define StmtNodeTy
158+
// @{
159+
#define __m(x__) x__,
160+
enum class StmtNodeTy { kUnk = -1, NODETY_FORALL_STMT(__m) };
161+
#undef __m
162+
// @}
163+
146164
//! String representations for IrNodeTy.
147165
// @{
148166
#define __m(x__) #x__,
@@ -152,6 +170,7 @@ const std::vector<std::string> kIrNodeTyReprs(
152170
// @}
153171

154172
std::ostream& operator<<(std::ostream& os, IrNodeTy type);
173+
std::ostream& operator<<(std::ostream& os, StmtNodeTy type);
155174

156175
struct Expr;
157176

0 commit comments

Comments
 (0)