Skip to content

Commit eaa6166

Browse files
authored
update static list and open train for some subgraphs (PaddlePaddle#65489)
1 parent eb7db8f commit eaa6166

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+40
-89
lines changed

test/ir/pir/cinn/sub_graphs/CMakeLists.txt

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,31 @@
11
if(WITH_GPU)
22
set(STATIC_BUILD_TESTS
3-
test_sub_graph_0
4-
test_sub_graph_1
5-
test_sub_graph_2
6-
test_sub_graph_3
7-
test_sub_graph_5
8-
test_sub_graph_10
93
test_sub_graph_12
10-
test_sub_graph_13
11-
test_sub_graph_16
12-
test_sub_graph_17
13-
test_sub_graph_18
14-
test_sub_graph_22
154
test_sub_graph_23
16-
test_sub_graph_24
17-
test_sub_graph_25
18-
test_sub_graph_26
19-
test_sub_graph_27
20-
test_sub_graph_28
21-
test_sub_graph_29
22-
test_sub_graph_30
23-
test_sub_graph_31
24-
test_sub_graph_33
255
test_sub_graph_34
26-
test_sub_graph_40
276
test_sub_graph_42
287
test_sub_graph_43
29-
test_sub_graph_44
308
test_sub_graph_48
319
test_sub_graph_49
3210
test_sub_graph_add_n
3311
test_sub_graph_add
34-
test_sub_graph_avg_pool2d
3512
test_sub_graph_chunk
36-
test_sub_graph_max_pool2d
37-
test_sub_graph_reshape
38-
test_sub_graph_swish
3913
test_sub_graph_mul_method
40-
test_sub_graph_adaptive_avg_pool2d
14+
test_sub_graph_squeeze_unsqueeze
4115
test_sub_graph_52
42-
test_sub_graph_55
4316
test_sub_graph_56
44-
test_sub_graph_59
4517
test_sub_graph_61
4618
test_sub_graph_62
47-
test_sub_graph_63
48-
test_sub_graph_66
4919
test_sub_graph_67
5020
test_sub_graph_69
5121
test_sub_graph_73
52-
test_sub_graph_74
5322
test_sub_graph_77
5423
test_sub_graph_79
5524
test_sub_graph_80
5625
test_sub_graph_82
57-
test_sub_graph_84
26+
test_sub_graph_85
5827
test_sub_graph_86
59-
test_sub_graph_87
60-
test_sub_graph_88)
28+
test_sub_graph_87)
6129

6230
file(
6331
GLOB DYNAMIC_BUILD_TESTS

test/ir/pir/cinn/sub_graphs/base.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,29 @@ def test_ast_prim_cinn(self):
9696
st.numpy(), cinn.numpy(), atol=self.atol
9797
)
9898
if self.with_train:
99-
st_loss = st_out.mean()
99+
if type(st_out) == tuple:
100+
st_loss, cinn_loss = 0, 0
101+
for i in range(len(st_out)):
102+
st_loss += st_out[i].mean()
103+
cinn_loss += cinn_out[i].mean()
104+
else:
105+
st_loss = st_out.mean()
106+
cinn_loss = cinn_out.mean()
100107
st_loss.backward()
101108
st_grad = []
102109
for i in range(len(st_inputs)):
103-
if st_inputs[i].dtype != paddle.int64:
110+
if (
111+
st_inputs[i].dtype != paddle.int64
112+
and st_inputs[i].grad is not None
113+
):
104114
st_grad.append(st_inputs[i].grad.numpy().copy())
105-
cinn_loss = cinn_out.mean()
106115
cinn_loss.backward()
107116
cinn_grad = []
108117
for i in range(len(cinn_inputs)):
109-
if cinn_inputs[i].dtype != paddle.int64:
118+
if (
119+
cinn_inputs[i].dtype != paddle.int64
120+
and cinn_inputs[i].grad is not None
121+
):
110122
cinn_grad.append(cinn_inputs[i].grad.numpy().copy())
111123
for i in range(len(cinn_grad)):
112124
np.testing.assert_allclose(

test/ir/pir/cinn/sub_graphs/test_sub_graph_0.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def init(self):
135135
paddle.rand(shape=[22, 512, 7, 7], dtype=paddle.float32),
136136
)
137137
self.net = LayerCase
138+
self.with_train = False
138139

139140
def set_flags(self):
140141
# NOTE(Aurelius84): cinn_op.pool2d only support pool_type='avg' under adaptive=True

test/ir/pir/cinn/sub_graphs/test_sub_graph_14.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ def init(self):
7272
paddle.rand(shape=[22, 128, 56, 56], dtype=paddle.float32),
7373
)
7474
self.net = LayerCase
75-
self.with_train = False
7675

7776

7877
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_16.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ def init(self):
115115
paddle.rand(shape=[22, 28, 56, 56], dtype=paddle.float32),
116116
)
117117
self.net = LayerCase
118-
self.with_train = False
119118

120119
def set_flags(self):
121120
# NOTE(Aurelius84): cinn_op.pool2d only support pool_type='avg' under adaptive=True

test/ir/pir/cinn/sub_graphs/test_sub_graph_18.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def init(self):
6868
paddle.rand(shape=[22, 1536, 8, 8], dtype=paddle.float32),
6969
)
7070
self.net = LayerCase
71-
self.with_train = False
7271
self.with_precision_compare = False
72+
self.with_train = False
7373

7474
# NOTE output mismatch with prim
7575

test/ir/pir/cinn/sub_graphs/test_sub_graph_20.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def init(self):
7777
paddle.rand(shape=[86, 192], dtype=paddle.float32),
7878
)
7979
self.net = LayerCase
80-
self.with_train = False
8180

8281

8382
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_25.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def init(self):
6868
paddle.rand(shape=[11, 1280, 7, 7], dtype=paddle.float32),
6969
)
7070
self.net = LayerCase
71-
self.with_train = False
7271
self.with_precision_compare = False
72+
self.with_train = False
7373

7474

7575
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_26.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def init(self):
9393
)
9494
self.net = LayerCase
9595
self.atol = 1e-5
96+
self.with_train = False
9697

9798

9899
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_27.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def init(self):
6464
)
6565
self.net = LayerCase
6666
self.with_precision_compare = False
67+
self.with_train = False
6768

6869
# NOTE prim + cinn lead to error
6970

test/ir/pir/cinn/sub_graphs/test_sub_graph_28.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def init(self):
6868
paddle.rand(shape=[10, 320, 8, 8], dtype=paddle.float32),
6969
)
7070
self.net = LayerCase
71-
self.with_train = False
7271
self.with_precision_compare = False
72+
self.with_train = False
7373

7474
# NOTE prim + cinn lead to error
7575

test/ir/pir/cinn/sub_graphs/test_sub_graph_29.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def init(self):
6868
paddle.rand(shape=[10, 2048, 10, 10], dtype=paddle.float32),
6969
)
7070
self.net = LayerCase
71-
self.with_train = False
7271
self.with_precision_compare = False
72+
self.with_train = False
7373

7474

7575
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_30.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,7 @@ def init(self):
694694
)
695695
self.net = LayerCase
696696
self.atol = 1e-1
697+
self.with_train = False
697698

698699
# NOTE prim + cinn lead to error
699700

test/ir/pir/cinn/sub_graphs/test_sub_graph_31.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def init(self):
6666
paddle.rand(shape=[22, 288, 14, 14], dtype=paddle.float32),
6767
)
6868
self.net = LayerCase
69-
self.with_train = False
7069
self.atol = 1e-8
70+
self.with_train = False
7171

7272

7373
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_33.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def init(self):
8484
paddle.rand(shape=[10, 256, 14, 14], dtype=paddle.float32),
8585
)
8686
self.net = LayerCase
87-
self.with_train = False
8887
self.atol = 1e-5
88+
self.with_train = False
8989

9090

9191
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_34.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def init(self):
5757
paddle.rand(shape=[10, 32, 56, 56], dtype=paddle.float32),
5858
)
5959
self.net = LayerCase
60-
self.with_train = False
6160
self.atol = 1e-8
61+
self.with_train = False
6262

6363

6464
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_43.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ def init(self):
258258
paddle.rand(shape=[1, 2048, 24, 36], dtype=paddle.float32),
259259
)
260260
self.net = LayerCase
261-
self.with_train = False
262261
self.atol = 1e-5
263262

264263

test/ir/pir/cinn/sub_graphs/test_sub_graph_44.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ def init(self):
143143
paddle.rand(shape=[1, 100, 256], dtype=paddle.float32),
144144
)
145145
self.net = LayerCase
146-
self.with_train = False
147146
self.atol = 1e-8
148147
self.with_cinn = False
149148

test/ir/pir/cinn/sub_graphs/test_sub_graph_48.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,8 @@ def init(self):
190190
paddle.rand(shape=[1, 625, 1], dtype=paddle.float32),
191191
)
192192
self.net = LayerCase
193-
self.with_train = False
194193
self.atol = 1e-5
194+
self.with_train = False
195195

196196

197197
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_49.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def init(self):
6666
paddle.rand(shape=[1], dtype=paddle.float32),
6767
)
6868
self.net = LayerCase
69-
self.with_train = False
7069

7170

7271
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_54.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def init(self):
117117
paddle.rand(shape=[1, 96, 128, 128], dtype=paddle.float32),
118118
)
119119
self.net = LayerCase
120-
self.with_train = False
121120

122121

123122
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_56.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def init(self):
7474
paddle.rand(shape=[24], dtype=paddle.float32),
7575
)
7676
self.net = LayerCase
77-
self.with_train = False
7877

7978
# NOTE prim + cinn lead to error
8079

test/ir/pir/cinn/sub_graphs/test_sub_graph_57.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def init(self):
4242
self.input_specs = []
4343
self.inputs = ()
4444
self.net = LayerCase
45-
self.with_train = False
4645

4746

4847
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_61.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def init(self):
9191
paddle.rand(shape=[1, 4], dtype=paddle.float32),
9292
)
9393
self.net = LayerCase
94-
self.with_train = False
9594
self.atol = 1e-5
9695

9796

test/ir/pir/cinn/sub_graphs/test_sub_graph_62.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def init(self):
7171
paddle.rand(shape=[1], dtype=paddle.float32),
7272
)
7373
self.net = LayerCase
74-
self.with_train = False
7574
self.atol = 1e-8
7675

7776

test/ir/pir/cinn/sub_graphs/test_sub_graph_65.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def init(self):
5555
]
5656
self.inputs = (paddle.rand(shape=[2, 2002], dtype=paddle.float32),)
5757
self.net = LayerCase
58-
self.with_train = False
5958

6059

6160
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_66.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def init(self):
6464
]
6565
self.inputs = (paddle.rand(shape=[2, 1788], dtype=paddle.float32),)
6666
self.net = LayerCase
67-
self.with_train = False
6867

6968
# NOTE prim + cinn lead to error
7069

test/ir/pir/cinn/sub_graphs/test_sub_graph_67.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def init(self):
134134
paddle.randint(low=0, high=10, shape=[1], dtype=paddle.int32),
135135
)
136136
self.net = LayerCase
137-
self.with_train = False
138137
self.with_cinn = False
138+
self.with_train = False
139139

140140
# NOTE prim + cinn lead to error
141141

test/ir/pir/cinn/sub_graphs/test_sub_graph_69.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def init(self):
6565
paddle.rand(shape=[1, 171888, 4], dtype=paddle.float32),
6666
)
6767
self.net = LayerCase
68-
self.with_train = False
6968
self.with_precision_compare = False
69+
self.with_train = False
7070

7171

7272
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_7.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def init(self):
8080
paddle.randint(low=0, high=10, shape=[49, 49], dtype=paddle.int64),
8181
)
8282
self.net = LayerCase
83-
self.with_train = False
8483
self.with_cinn = False
84+
self.with_train = False
8585

8686
# NOTE prim + cinn lead to error
8787
# NOTE can not pass when atol=1e-8 with prim

test/ir/pir/cinn/sub_graphs/test_sub_graph_72.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ def init(self):
143143
self.input_specs = []
144144
self.inputs = ()
145145
self.net = LayerCase
146-
self.with_train = False
147146

148147

149148
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_73.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def init(self):
9898
paddle.rand(shape=[2], dtype=paddle.float32),
9999
)
100100
self.net = LayerCase
101-
self.with_train = False
102101
self.with_cinn = False
102+
self.with_train = False
103103

104104

105105
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_74.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def init(self):
7575
paddle.rand(shape=[2], dtype=paddle.float32),
7676
)
7777
self.net = LayerCase
78-
self.with_train = False
7978
self.with_precision_compare = False
79+
self.with_train = False
8080

8181

8282
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_77.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,6 @@ def init(self):
209209
paddle.rand(shape=[1], dtype=paddle.float32),
210210
)
211211
self.net = LayerCase
212-
self.with_train = False
213212

214213

215214
if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_80.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def init(self):
125125
paddle.rand(shape=[1, 3, 48, 48, 1], dtype=paddle.float32),
126126
)
127127
self.net = LayerCase
128-
self.with_train = False
129128

130129

131130
if __name__ == '__main__':

0 commit comments

Comments
 (0)