15
15
# repo: PaddleClas
16
16
# model: ppcls^configs^ImageNet^Distillation^resnet34_distill_resnet18_afd
17
17
# method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||method:pow||method:mean||api:paddle.nn.functional.pooling.adaptive_avg_pool2d||method:reshape||api:paddle.tensor.manipulation.stack
18
- import unittest
19
18
20
- import numpy as np
19
+ from base import * # noqa: F403
21
20
22
- import paddle
21
+ from paddle . static import InputSpec
23
22
24
23
25
24
class LayerCase (paddle .nn .Layer ):
@@ -72,8 +71,59 @@ def forward(
72
71
return var_40
73
72
74
73
75
- class TestLayer (unittest .TestCase ):
76
- def setUp (self ):
74
+ class TestLayer (TestBase ):
75
+ def init (self ):
76
+ self .input_specs = [
77
+ InputSpec (
78
+ shape = (- 1 , - 1 , - 1 , - 1 ),
79
+ dtype = paddle .float32 ,
80
+ name = None ,
81
+ stop_gradient = False ,
82
+ ),
83
+ InputSpec (
84
+ shape = (- 1 , - 1 , - 1 , - 1 ),
85
+ dtype = paddle .float32 ,
86
+ name = None ,
87
+ stop_gradient = False ,
88
+ ),
89
+ InputSpec (
90
+ shape = (- 1 , - 1 , - 1 , - 1 ),
91
+ dtype = paddle .float32 ,
92
+ name = None ,
93
+ stop_gradient = False ,
94
+ ),
95
+ InputSpec (
96
+ shape = (- 1 , - 1 , - 1 , - 1 ),
97
+ dtype = paddle .float32 ,
98
+ name = None ,
99
+ stop_gradient = False ,
100
+ ),
101
+ InputSpec (
102
+ shape = (- 1 , - 1 , - 1 , - 1 ),
103
+ dtype = paddle .float32 ,
104
+ name = None ,
105
+ stop_gradient = False ,
106
+ ),
107
+ InputSpec (
108
+ shape = (- 1 , - 1 , - 1 , - 1 ),
109
+ dtype = paddle .float32 ,
110
+ name = None ,
111
+ stop_gradient = False ,
112
+ ),
113
+ InputSpec (
114
+ shape = (- 1 , - 1 , - 1 , - 1 ),
115
+ dtype = paddle .float32 ,
116
+ name = None ,
117
+ stop_gradient = False ,
118
+ ),
119
+ InputSpec (
120
+ shape = (- 1 , - 1 , - 1 , - 1 ),
121
+ dtype = paddle .float32 ,
122
+ name = None ,
123
+ stop_gradient = False ,
124
+ ),
125
+ ]
126
+
77
127
self .inputs = (
78
128
paddle .rand (shape = [22 , 64 , 56 , 56 ], dtype = paddle .float32 ),
79
129
paddle .rand (shape = [22 , 64 , 56 , 56 ], dtype = paddle .float32 ),
@@ -86,34 +136,9 @@ def setUp(self):
86
136
)
87
137
self .net = LayerCase ()
88
138
89
- def train (self , net , to_static , with_prim = False , with_cinn = False ):
90
- if to_static :
91
- paddle .set_flags ({'FLAGS_prim_all' : with_prim })
92
- if with_cinn :
93
- build_strategy = paddle .static .BuildStrategy ()
94
- build_strategy .build_cinn_pass = True
95
- net = paddle .jit .to_static (
96
- net , build_strategy = build_strategy , full_graph = True
97
- )
98
- else :
99
- net = paddle .jit .to_static (net , full_graph = True )
100
- paddle .seed (123 )
101
- outs = net (* self .inputs )
102
- return outs
103
-
104
- def test_ast_prim_cinn (self ):
105
- st_out = self .train (self .net , to_static = True )
139
+ def set_flags (self ):
106
140
# NOTE(Aurelius84): cinn_op.pool2d only support pool_type='avg' under adaptive=True
107
141
paddle .set_flags ({"FLAGS_deny_cinn_ops" : "pool2d" })
108
- cinn_out = self .train (
109
- self .net , to_static = True , with_prim = True , with_cinn = True
110
- )
111
- # TODO(Aurelius84): It contains reduce operation and atol can't satisfy
112
- # 1e-8, so we set it to 1e-6.
113
- for st , cinn in zip (
114
- paddle .utils .flatten (st_out ), paddle .utils .flatten (cinn_out )
115
- ):
116
- np .testing .assert_allclose (st .numpy (), cinn .numpy (), atol = 1e-6 )
117
142
118
143
119
144
if __name__ == '__main__' :
0 commit comments