24
24
25
25
26
26
def logit(x, eps):
    """NumPy reference implementation of the logit (log-odds) function.

    Computes ``log(x / (1 - x))``. When ``eps`` is truthy (non-zero),
    ``x`` is first clipped into ``[eps, 1 - eps]`` so the log is always
    finite. When ``eps`` is 0 (or otherwise falsy), inputs outside
    ``[0, 1]`` are mapped to NaN instead of being clipped.
    """
    if eps:
        # Clip into [eps, 1 - eps] before taking the log-odds.
        clipped = np.maximum(np.minimum(x, 1.0 - eps), eps)
        return np.log(clipped / (1.0 - clipped))
    # eps == 0: out-of-domain values become NaN (dtype preserved).
    return np.where(
        (x < 0.0) | (x > 1.0),
        np.array(np.nan, dtype=x.dtype),
        np.log(x / (1.0 - x)),
    )
30
37
31
38
32
39
def logit_grad (x , eps = 1e-8 ):
33
- tmp_x = np .select (
34
- [x < eps , x > (1.0 - eps )], [x * 0.0 , x * 0.0 ], default = - 1.0
35
- )
36
- x_1 = 1.0 - x
37
- _x = np .select ([tmp_x == - 1.0 ], [np .reciprocal (x * x_1 )], default = 0.0 )
40
+ if eps :
41
+ tmp_x = np .select (
42
+ [x < eps , x > (1.0 - eps )], [x * 0.0 , x * 0.0 ], default = - 1.0
43
+ )
44
+ x_1 = 1.0 - x
45
+ _x = np .select ([tmp_x == - 1.0 ], [np .reciprocal (x * x_1 )], default = 0.0 )
46
+ else :
47
+ tmp_x = np .select (
48
+ [x < 0.0 , x > 1.0 ],
49
+ [np .array (np .nan , dtype = x .dtype ), np .array (np .nan , dtype = x .dtype )],
50
+ default = - 1.0 ,
51
+ )
52
+ x_1 = 1.0 - x
53
+ _x = np .select (
54
+ [tmp_x == - 1.0 ],
55
+ [np .reciprocal (x * x_1 )],
56
+ default = np .array (np .nan , dtype = x .dtype ),
57
+ )
58
+
38
59
if _x .size == 0 :
39
60
dout = np .full_like (x , fill_value = 0.0 )
40
61
else :
@@ -162,9 +183,13 @@ def set_attrs(self):
162
183
163
184
164
185
class TestLogitAPI (unittest .TestCase ):
165
- def setUp (self ):
186
+ def init_data (self ):
166
187
self .x_shape = [120 ]
167
- self .x = np .random .uniform (0.0 , 1.0 , self .x_shape ).astype (np .float32 )
188
+ self .x_dtype = "float32"
189
+
190
+ def setUp (self ):
191
+ self .init_data ()
192
+ self .x = np .random .uniform (- 1.0 , 1.0 , self .x_shape ).astype (self .x_dtype )
168
193
self .place = (
169
194
paddle .CUDAPlace (0 )
170
195
if paddle .base .core .is_compiled_with_cuda ()
@@ -175,22 +200,38 @@ def check_api(self, eps=1e-8):
175
200
ref_out = logit (self .x , eps )
176
201
# test static api
177
202
with paddle .static .program_guard (paddle .static .Program ()):
178
- x = paddle .static .data (name = 'x' , shape = self .x_shape )
203
+ x = paddle .static .data (
204
+ name = 'x' , shape = self .x_shape , dtype = self .x_dtype
205
+ )
179
206
y = paddle .logit (x , eps )
180
207
exe = paddle .static .Executor (self .place )
181
208
out = exe .run (feed = {'x' : self .x }, fetch_list = [y ])
182
209
np .testing .assert_allclose (out [0 ], ref_out , rtol = 1e-05 )
183
210
# test dygrapg api
184
211
paddle .disable_static ()
185
- x = paddle .to_tensor (self .x )
186
- y = paddle .logit (x , 1e-8 )
212
+ x = paddle .to_tensor (self .x , dtype = self . x_dtype )
213
+ y = paddle .logit (x , eps )
187
214
np .testing .assert_allclose (y .numpy (), ref_out , rtol = 1e-05 )
188
215
paddle .enable_static ()
189
216
217
+ def check_api_grad (self , eps = 1e-8 ):
218
+ ref_grad = logit_grad (self .x , eps )
219
+ numpy_tensor = np .ones (self .x_shape ).astype (self .x_dtype )
220
+ # test dygrapg api
221
+ paddle .disable_static ()
222
+ paddle_outgrad = paddle .to_tensor (numpy_tensor / numpy_tensor .size )
223
+ x = paddle .to_tensor (self .x , dtype = self .x_dtype )
224
+ x .stop_gradient = False
225
+ y = paddle .logit (x , eps )
226
+ x_grad = paddle .grad ([y ], [x ], [paddle_outgrad ])
227
+ np .testing .assert_allclose (x_grad [0 ].numpy (), ref_grad , rtol = 1e-05 )
228
+ paddle .enable_static ()
229
+
190
230
def test_check_api (self ):
191
231
paddle .enable_static ()
192
232
for eps in [1e-6 , 0.0 ]:
193
233
self .check_api (eps )
234
+ self .check_api_grad (eps )
194
235
195
236
def test_errors (self ):
196
237
paddle .enable_static ()
@@ -202,5 +243,17 @@ def test_errors(self):
202
243
self .assertRaises (TypeError , paddle .logit , x , dtype = 'int32' )
203
244
204
245
246
class TestLogitAPICase1(TestLogitAPI):
    """float64 variant of the logit API tests.

    Must inherit from TestLogitAPI (not unittest.TestCase directly):
    the class only overrides the init_data hook, so without inheriting
    the test_* methods it would collect and run zero tests.
    """

    def init_data(self):
        self.x_shape = [120]
        self.x_dtype = "float64"
250
+
251
+
252
class TestLogitAPICase2(TestLogitAPI):
    """float16 variant of the logit API tests.

    Must inherit from TestLogitAPI (not unittest.TestCase directly):
    the class only overrides the init_data hook, so without inheriting
    the test_* methods it would collect and run zero tests.
    """

    def init_data(self):
        self.x_shape = [120]
        self.x_dtype = "float16"
256
+
257
+
205
258
if __name__ == "__main__":
    # Allow running this test module directly from the command line.
    unittest.main()