28
28
29
29
# Correct: General.
30
30
class TestSqueezeOp (OpTest ):
31
-
32
31
def setUp (self ):
33
32
self .op_type = "squeeze2"
34
33
self .python_api = paddle .squeeze
@@ -40,7 +39,7 @@ def setUp(self):
40
39
self .init_attrs ()
41
40
self .outputs = {
42
41
"Out" : self .inputs ["X" ].reshape (self .new_shape ),
43
- "XShape" : np .random .random (self .ori_shape ).astype ("float64" )
42
+ "XShape" : np .random .random (self .ori_shape ).astype ("float64" ),
44
43
}
45
44
46
45
def test_check_output (self ):
@@ -60,7 +59,6 @@ def init_attrs(self):
60
59
61
60
# Correct: There is mins axis.
62
61
class TestSqueezeOp1 (TestSqueezeOp ):
63
-
64
62
def init_test_case (self ):
65
63
self .ori_shape = (1 , 20 , 1 , 5 )
66
64
self .axes = (0 , - 2 )
@@ -69,7 +67,6 @@ def init_test_case(self):
69
67
70
68
# Correct: No axes input.
71
69
class TestSqueezeOp2 (TestSqueezeOp ):
72
-
73
70
def init_test_case (self ):
74
71
self .ori_shape = (1 , 20 , 1 , 5 )
75
72
self .axes = ()
@@ -78,15 +75,13 @@ def init_test_case(self):
78
75
79
76
# Correct: Just part of axes be squeezed.
80
77
class TestSqueezeOp3 (TestSqueezeOp ):
81
-
82
78
def init_test_case (self ):
83
79
self .ori_shape = (6 , 1 , 5 , 1 , 4 , 1 )
84
80
self .axes = (1 , - 1 )
85
81
self .new_shape = (6 , 5 , 1 , 4 )
86
82
87
83
88
84
class TestSqueeze2AxesTensor (UnittestBase ):
89
-
90
85
def init_info (self ):
91
86
self .shapes = [[2 , 3 , 4 ]]
92
87
self .save_path = os .path .join (self .temp_dir .name , 'squeeze_tensor' )
@@ -123,7 +118,6 @@ def test_static(self):
123
118
124
119
125
120
class TestSqueeze2AxesTensorList (UnittestBase ):
126
-
127
121
def init_info (self ):
128
122
self .shapes = [[2 , 3 , 4 ]]
129
123
self .save_path = os .path .join (self .temp_dir .name , 'squeeze_tensor' )
@@ -140,7 +134,7 @@ def test_static(self):
140
134
# axes is a list[Variable]
141
135
axes = [
142
136
paddle .full ([1 ], 0 , dtype = 'int32' ),
143
- paddle .full ([1 ], 2 , dtype = 'int32' )
137
+ paddle .full ([1 ], 2 , dtype = 'int32' ),
144
138
]
145
139
out = paddle .squeeze (feat , axes )
146
140
out2 = paddle .fluid .layers .squeeze (feat , axes )
@@ -162,5 +156,37 @@ def test_static(self):
162
156
self .assertEqual (infer_out .shape , (2 , 3 , 10 ))
163
157
164
158
159
+ # test api
160
+ class TestSqueezeAPI (unittest .TestCase ):
161
+ def setUp (self ):
162
+ self .executed_api ()
163
+
164
+ def executed_api (self ):
165
+ self .squeeze = paddle .squeeze
166
+
167
+ def test_api (self ):
168
+ paddle .disable_static ()
169
+ input_data = np .random .random ([3 , 2 , 1 ]).astype ("float32" )
170
+ x = paddle .to_tensor (input_data )
171
+ out = self .squeeze (x , axis = 2 )
172
+ out .backward ()
173
+
174
+ self .assertEqual (out .shape , [3 , 2 ])
175
+
176
+ paddle .enable_static ()
177
+
178
+ def test_error (self ):
179
+ def test_axes_type ():
180
+ x2 = paddle .static .data (name = "x2" , shape = [2 , 1 , 25 ], dtype = "int32" )
181
+ self .squeeze (x2 , axis = 2.1 )
182
+
183
+ self .assertRaises (TypeError , test_axes_type )
184
+
185
+
186
+ class TestSqueezeInplaceAPI (TestSqueezeAPI ):
187
+ def executed_api (self ):
188
+ self .squeeze = paddle .squeeze_
189
+
190
+
165
191
if __name__ == "__main__" :
166
192
unittest .main ()
0 commit comments