Commit d1026fc

[CINN]Support Dynamic Shape for CINN subgraph UT (#64076)
* [CINN]Support Dynamic Shape for CINN subgraph UT
* add more UT
* add prim flag
* fix flag
* fix atol
* fix timeout
* fix timeout
* fix sub_graph_3
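The new harness runs each subgraph twice through paddle.jit.to_static (a dygraph-to-static baseline vs. a prim+CINN build) and marks input dims as dynamic with -1 entries in InputSpec. A minimal sketch of that pattern, using only the public Paddle APIs that appear in this diff (SimpleNet is a hypothetical stand-in for the LayerCase nets below):

# Sketch: a dim of -1 in an InputSpec stays symbolic, so the compiled
# program must handle any runtime extent along that axis.
import paddle
from paddle.static import InputSpec

class SimpleNet(paddle.nn.Layer):  # hypothetical stand-in for a LayerCase
    def forward(self, x):
        return paddle.nn.functional.relu(x)

specs = [InputSpec(shape=(-1, -1, -1, -1), dtype=paddle.float32)]
net = paddle.jit.to_static(SimpleNet(), full_graph=True, input_spec=specs)
out = net(paddle.rand([22, 64, 56, 56]))  # concrete shape only at call time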
1 parent c63ad5d commit d1026fc

File tree

5 files changed: +186 -127 lines

test/ir/pir/cinn/sub_graphs/base.py (+75)
@@ -0,0 +1,75 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+
+
+class TestBase(unittest.TestCase):
+    def setUp(self):
+        # default settings
+        self.net = None
+        self.inputs = None
+        self.input_specs = None
+        self.with_prim = True
+        self.with_cinn = True
+        self.atol = 1e-6
+        # override with customized settings
+        self.init()
+
+    def init(self):
+        pass
+
+    def set_flags(self):
+        pass
+
+    def train(self, net, to_static, with_prim=False, with_cinn=False):
+        if to_static:
+            paddle.set_flags({'FLAGS_prim_all': with_prim})
+            if with_cinn:
+                build_strategy = paddle.static.BuildStrategy()
+                build_strategy.build_cinn_pass = True
+                net = paddle.jit.to_static(
+                    net,
+                    build_strategy=build_strategy,
+                    full_graph=True,
+                    input_spec=self.input_specs,
+                )
+            else:
+                net = paddle.jit.to_static(
+                    net, full_graph=True, input_spec=self.input_specs
+                )
+        paddle.seed(123)
+        net.eval()
+        outs = net(*self.inputs)
+        return outs
+
+    def test_ast_prim_cinn(self):
+        if not self.net:
+            return
+        st_out = self.train(self.net, to_static=True)
+        self.set_flags()
+        cinn_out = self.train(
+            self.net,
+            to_static=True,
+            with_prim=self.with_prim,
+            with_cinn=self.with_cinn,
+        )
+        for st, cinn in zip(
+            paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
+        ):
+            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=self.atol)
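With this base class in place, a concrete test only describes its data: init() supplies input_specs, inputs, and net, and set_flags() applies per-test flags before the prim+CINN run. A sketch of the override pattern the four files below follow (DummyLayer and TestDummy are hypothetical, not part of this commit):

# Hypothetical subclass mirroring the test_sub_graph_*.py files.
import paddle
from paddle.static import InputSpec
from base import *  # noqa: F403  # provides TestBase

class DummyLayer(paddle.nn.Layer):  # hypothetical stand-in for a LayerCase
    def forward(self, x):
        return x * 2

class TestDummy(TestBase):
    def init(self):
        self.input_specs = [
            InputSpec(shape=(-1, -1), dtype=paddle.float32)
        ]
        self.inputs = (paddle.rand([8, 16], dtype=paddle.float32),)
        self.net = DummyLayer()

    def set_flags(self):
        # optional: keep an op CINN can't handle out of the pass,
        # as test_sub_graph_0 does for pool2d
        paddle.set_flags({"FLAGS_deny_cinn_ops": "pool2d"})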

test/ir/pir/cinn/sub_graphs/test_sub_graph_0.py (+56 -31)
@@ -15,11 +15,10 @@
 # repo: PaddleClas
 # model: ppcls^configs^ImageNet^Distillation^resnet34_distill_resnet18_afd
 # method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||api:paddle.tensor.manipulation.stack
-import unittest
 
-import numpy as np
+from base import *  # noqa: F403
 
-import paddle
+from paddle.static import InputSpec
 
 
 class LayerCase(paddle.nn.Layer):
@@ -72,8 +71,59 @@ def forward(
         return var_40
 
 
-class TestLayer(unittest.TestCase):
-    def setUp(self):
+class TestLayer(TestBase):
+    def init(self):
+        self.input_specs = [
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            ),
+        ]
+
         self.inputs = (
             paddle.rand(shape=[22, 64, 56, 56], dtype=paddle.float32),
             paddle.rand(shape=[22, 64, 56, 56], dtype=paddle.float32),
@@ -86,34 +136,9 @@ def setUp(self):
         )
         self.net = LayerCase()
 
-    def train(self, net, to_static, with_prim=False, with_cinn=False):
-        if to_static:
-            paddle.set_flags({'FLAGS_prim_all': with_prim})
-            if with_cinn:
-                build_strategy = paddle.static.BuildStrategy()
-                build_strategy.build_cinn_pass = True
-                net = paddle.jit.to_static(
-                    net, build_strategy=build_strategy, full_graph=True
-                )
-            else:
-                net = paddle.jit.to_static(net, full_graph=True)
-        paddle.seed(123)
-        outs = net(*self.inputs)
-        return outs
-
-    def test_ast_prim_cinn(self):
-        st_out = self.train(self.net, to_static=True)
+    def set_flags(self):
         # NOTE(Aurelius84): cinn_op.pool2d only support pool_type='avg' under adaptive=True
         paddle.set_flags({"FLAGS_deny_cinn_ops": "pool2d"})
-        cinn_out = self.train(
-            self.net, to_static=True, with_prim=True, with_cinn=True
-        )
-        # TODO(Aurelius84): It contains reduce operation and atol can't satisfy
-        # 1e-8, so we set it to 1e-6.
-        for st, cinn in zip(
-            paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
-        ):
-            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6)
 
 
 if __name__ == '__main__':

test/ir/pir/cinn/sub_graphs/test_sub_graph_1.py (+12 -31)
@@ -15,11 +15,9 @@
 # repo: PaddleClas
 # model: ppcls^configs^ImageNet^Distillation^resnet34_distill_resnet18_afd
 # api:paddle.nn.functional.pooling.adaptive_avg_pool2d||api:paddle.tensor.manipulation.flatten||api:paddle.nn.functional.common.linear
-import unittest
+from base import *  # noqa: F403
 
-import numpy as np
-
-import paddle
+from paddle.static import InputSpec
 
 
 class LayerCase(paddle.nn.Layer):
@@ -50,38 +48,21 @@ def forward(
         return var_3
 
 
-class TestLayer(unittest.TestCase):
-    def setUp(self):
+class TestLayer(TestBase):
+    def init(self):
+        self.input_specs = [
+            InputSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=True,
+            )
+        ]
         self.inputs = (
             paddle.rand(shape=[10, 512, 7, 7], dtype=paddle.float32),
         )
         self.net = LayerCase()
 
-    def train(self, net, to_static, with_prim=False, with_cinn=False):
-        if to_static:
-            paddle.set_flags({'FLAGS_prim_all': with_prim})
-            if with_cinn:
-                build_strategy = paddle.static.BuildStrategy()
-                build_strategy.build_cinn_pass = True
-                net = paddle.jit.to_static(
-                    net, build_strategy=build_strategy, full_graph=True
-                )
-            else:
-                net = paddle.jit.to_static(net, full_graph=True)
-        paddle.seed(123)
-        outs = net(*self.inputs)
-        return outs
-
-    def test_ast_prim_cinn(self):
-        st_out = self.train(self.net, to_static=True)
-        cinn_out = self.train(
-            self.net, to_static=True, with_prim=True, with_cinn=True
-        )
-        for st, cinn in zip(
-            paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
-        ):
-            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6)
-
 
 if __name__ == '__main__':
     unittest.main()

test/ir/pir/cinn/sub_graphs/test_sub_graph_3.py (+31 -34)
@@ -15,11 +15,9 @@
 # repo: PaddleClas
 # model: ppcls^configs^ImageNet^LeViT^LeViT_128
 # api:paddle.tensor.manipulation.reshape||api:paddle.tensor.linalg.transpose||api:paddle.tensor.linalg.transpose||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||method:__getitem__||api:paddle.tensor.manipulation.gather||api:paddle.tensor.manipulation.concat||api:paddle.tensor.linalg.transpose||method:reshape||api:paddle.tensor.linalg.transpose||api:paddle.tensor.linalg.matmul||method:__mul__||method:__add__||api:paddle.nn.functional.activation.softmax||api:paddle.tensor.linalg.matmul||api:paddle.tensor.linalg.transpose||api:paddle.tensor.manipulation.reshape
-import unittest
+from base import *  # noqa: F403
 
-import numpy as np
-
-import paddle
+from paddle.static import InputSpec
 
 
 class LayerCase(paddle.nn.Layer):
@@ -57,8 +55,35 @@ def forward(
         return var_17
 
 
-class TestLayer(unittest.TestCase):
-    def setUp(self):
+class TestLayer(TestBase):
+    def init(self):
+        # FIXME(Aurelius84): CI timeout > 600 s
+        # self.input_specs = [
+        #     InputSpec(
+        #         shape=(-1, -1, -1),
+        #         dtype=paddle.float32,
+        #         name=None,
+        #         stop_gradient=False,
+        #     ),
+        #     InputSpec(
+        #         shape=(-1, -1, -1, -1),
+        #         dtype=paddle.float32,
+        #         name=None,
+        #         stop_gradient=False,
+        #     ),
+        #     InputSpec(
+        #         shape=(-1, -1, -1, -1),
+        #         dtype=paddle.float32,
+        #         name=None,
+        #         stop_gradient=False,
+        #     ),
+        #     InputSpec(
+        #         shape=(-1, -1),
+        #         dtype=paddle.int64,
+        #         name=None,
+        #         stop_gradient=True,
+        #     ),
+        # ]
         self.inputs = (
             paddle.rand(shape=[22, 16, 256], dtype=paddle.float32),
             paddle.rand(shape=[22, 16, 49, 16], dtype=paddle.float32),
@@ -67,34 +92,6 @@ def setUp(self):
         )
         self.net = LayerCase()
 
-    def train(self, net, to_static, with_prim=False, with_cinn=False):
-        if to_static:
-            paddle.set_flags({'FLAGS_prim_all': with_prim})
-            if with_cinn:
-                build_strategy = paddle.static.BuildStrategy()
-                build_strategy.build_cinn_pass = True
-                net = paddle.jit.to_static(
-                    net, build_strategy=build_strategy, full_graph=True
-                )
-            else:
-                net = paddle.jit.to_static(net, full_graph=True)
-        paddle.seed(123)
-        outs = net(*self.inputs)
-        return outs
-
-    def test_ast_prim_cinn(self):
-        # TODO(Aurelius84): deny cinn_op.gather
-        paddle.set_flags({"FLAGS_deny_cinn_ops": "gather"})
-        st_out = self.train(self.net, to_static=True)
-        cinn_out = self.train(
-            self.net, to_static=True, with_prim=True, with_cinn=False
-        )
-        # TODO(Aurelius84): fix precison
-        for st, cinn in zip(
-            paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
-        ):
-            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1)
-
 
 if __name__ == '__main__':
     unittest.main()

test/ir/pir/cinn/sub_graphs/test_sub_graph_4.py (+12 -31)
@@ -15,11 +15,9 @@
 # repo: PaddleClas
 # model: ppcls^configs^ImageNet^LeViT^LeViT_128
 # api:paddle.nn.functional.common.linear||method:flatten
-import unittest
+from base import *  # noqa: F403
 
-import numpy as np
-
-import paddle
+from paddle.static import InputSpec
 
 
 class LayerCase(paddle.nn.Layer):
@@ -41,36 +39,19 @@ def forward(
         return var_2, var_1
 
 
-class TestLayer(unittest.TestCase):
-    def setUp(self):
+class TestLayer(TestBase):
+    def init(self):
+        self.input_specs = [
+            InputSpec(
+                shape=(-1, -1, -1),
+                dtype=paddle.float32,
+                name=None,
+                stop_gradient=False,
+            )
+        ]
         self.inputs = (paddle.rand(shape=[22, 196, 128], dtype=paddle.float32),)
         self.net = LayerCase()
 
-    def train(self, net, to_static, with_prim=False, with_cinn=False):
-        if to_static:
-            paddle.set_flags({'FLAGS_prim_all': with_prim})
-            if with_cinn:
-                build_strategy = paddle.static.BuildStrategy()
-                build_strategy.build_cinn_pass = True
-                net = paddle.jit.to_static(
-                    net, build_strategy=build_strategy, full_graph=True
-                )
-            else:
-                net = paddle.jit.to_static(net, full_graph=True)
-        paddle.seed(123)
-        outs = net(*self.inputs)
-        return outs
-
-    def test_ast_prim_cinn(self):
-        st_out = self.train(self.net, to_static=True)
-        cinn_out = self.train(
-            self.net, to_static=True, with_prim=True, with_cinn=True
-        )
-        for st, cinn in zip(
-            paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
-        ):
-            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6)
-
 
 if __name__ == '__main__':
     unittest.main()

0 commit comments