Commit f6316d4

[Accuracy diff No.34-35, 63-64] Fix accuracy diff for logit API (#72973)
* logit
* fix
* fix
1 parent b8f2821 commit f6316d4

2 files changed: +86 -20 lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 20 additions & 7 deletions
```diff
@@ -1213,10 +1213,17 @@ struct LogitGradFunctor {
   template <typename Device, typename X, typename dOut, typename dX, typename P>
   void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
     // logit(x)' = 1/(x*(1-x))
-    dx.device(d) =
-        (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
-            .select(p.constant(static_cast<T>(0)),
-                    dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
+    if (!eps) {
+      dx.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
+                         .select(p.constant(static_cast<T>(NAN)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    } else {
+      dx.device(d) = (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
+                         .select(p.constant(static_cast<T>(0)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    }
   }
 };
 
@@ -3359,9 +3366,15 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
   // logit(x)' = 1/(x*(1-x))
   __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
     MT x = static_cast<MT>(arg_x);
-    MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
-                ? zero
-                : (static_cast<MT>(dout) / (x * (one - x)));
+    MT dx;
+    if (!eps) {
+      dx = (x < zero || x > one) ? static_cast<T>(NAN)
+                                 : (static_cast<MT>(dout) / (x * (one - x)));
+    } else {
+      dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
+               ? zero
+               : (static_cast<MT>(dout) / (x * (one - x)));
+    }
     return static_cast<T>(dx);
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() {
```
test/legacy_test/test_logit_op.py

Lines changed: 66 additions & 13 deletions
```diff
@@ -24,17 +24,38 @@
 
 
 def logit(x, eps):
-    x_min = np.minimum(x, 1.0 - eps)
-    x_max = np.maximum(x_min, eps)
-    return np.log(x_max / (1.0 - x_max))
+    if eps:
+        x_min = np.minimum(x, 1.0 - eps)
+        x_max = np.maximum(x_min, eps)
+        return np.log(x_max / (1.0 - x_max))
+    else:
+        return np.where(
+            (x < 0.0) | (x > 1.0),
+            np.array(np.nan, dtype=x.dtype),
+            np.log(x / (1.0 - x)),
+        )
 
 
 def logit_grad(x, eps=1e-8):
-    tmp_x = np.select(
-        [x < eps, x > (1.0 - eps)], [x * 0.0, x * 0.0], default=-1.0
-    )
-    x_1 = 1.0 - x
-    _x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
+    if eps:
+        tmp_x = np.select(
+            [x < eps, x > (1.0 - eps)], [x * 0.0, x * 0.0], default=-1.0
+        )
+        x_1 = 1.0 - x
+        _x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
+    else:
+        tmp_x = np.select(
+            [x < 0.0, x > 1.0],
+            [np.array(np.nan, dtype=x.dtype), np.array(np.nan, dtype=x.dtype)],
+            default=-1.0,
+        )
+        x_1 = 1.0 - x
+        _x = np.select(
+            [tmp_x == -1.0],
+            [np.reciprocal(x * x_1)],
+            default=np.array(np.nan, dtype=x.dtype),
+        )
+
     if _x.size == 0:
         dout = np.full_like(x, fill_value=0.0)
     else:
```
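A quick sanity check of the reference `logit` above, with values easy to verify by hand (NumPy may emit RuntimeWarnings for the masked lanes; `logit_grad`'s tail lies outside this hunk, so it is not exercised here):

```python
import numpy as np

x = np.array([-0.5, 0.3, 1.5], dtype=np.float32)

print(logit(x, eps=0.0))
# [nan, -0.8473, nan]           log(0.3/0.7) ~ -0.8473; NaN outside [0, 1]

print(logit(x, eps=1e-6))
# [-13.8155, -0.8473, 13.8155]  out-of-range lanes clip to eps / 1 - eps
```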
```diff
@@ -162,9 +183,13 @@ def set_attrs(self):
 
 
 class TestLogitAPI(unittest.TestCase):
-    def setUp(self):
+    def init_data(self):
         self.x_shape = [120]
-        self.x = np.random.uniform(0.0, 1.0, self.x_shape).astype(np.float32)
+        self.x_dtype = "float32"
+
+    def setUp(self):
+        self.init_data()
+        self.x = np.random.uniform(-1.0, 1.0, self.x_shape).astype(self.x_dtype)
         self.place = (
             paddle.CUDAPlace(0)
             if paddle.base.core.is_compiled_with_cuda()
```
```diff
@@ -175,22 +200,38 @@ def check_api(self, eps=1e-8):
         ref_out = logit(self.x, eps)
         # test static api
         with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data(name='x', shape=self.x_shape)
+            x = paddle.static.data(
+                name='x', shape=self.x_shape, dtype=self.x_dtype
+            )
             y = paddle.logit(x, eps)
             exe = paddle.static.Executor(self.place)
             out = exe.run(feed={'x': self.x}, fetch_list=[y])
         np.testing.assert_allclose(out[0], ref_out, rtol=1e-05)
         # test dygraph api
         paddle.disable_static()
-        x = paddle.to_tensor(self.x)
-        y = paddle.logit(x, 1e-8)
+        x = paddle.to_tensor(self.x, dtype=self.x_dtype)
+        y = paddle.logit(x, eps)
         np.testing.assert_allclose(y.numpy(), ref_out, rtol=1e-05)
         paddle.enable_static()
 
+    def check_api_grad(self, eps=1e-8):
+        ref_grad = logit_grad(self.x, eps)
+        numpy_tensor = np.ones(self.x_shape).astype(self.x_dtype)
+        # test dygraph api
+        paddle.disable_static()
+        paddle_outgrad = paddle.to_tensor(numpy_tensor / numpy_tensor.size)
+        x = paddle.to_tensor(self.x, dtype=self.x_dtype)
+        x.stop_gradient = False
+        y = paddle.logit(x, eps)
+        x_grad = paddle.grad([y], [x], [paddle_outgrad])
+        np.testing.assert_allclose(x_grad[0].numpy(), ref_grad, rtol=1e-05)
+        paddle.enable_static()
+
     def test_check_api(self):
         paddle.enable_static()
         for eps in [1e-6, 0.0]:
             self.check_api(eps)
+            self.check_api_grad(eps)
 
     def test_errors(self):
         paddle.enable_static()
```
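For intuition, a minimal dygraph sketch of what `check_api_grad` exercises, assuming a Paddle build that includes this commit (the printed values are approximate):

```python
import paddle

paddle.disable_static()
x = paddle.to_tensor([-0.5, 0.3, 1.5], dtype="float32", stop_gradient=False)

y = paddle.logit(x, eps=0.0)
(dx,) = paddle.grad([y], [x], [paddle.ones_like(y)])
print(dx.numpy())  # [nan, 4.7619, nan]   1/(0.3*0.7) ~ 4.7619

y = paddle.logit(x, eps=1e-6)
(dx,) = paddle.grad([y], [x], [paddle.ones_like(y)])
print(dx.numpy())  # [0., 4.7619, 0.]     clipped lanes get zero gradient
```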
```diff
@@ -202,5 +243,17 @@ def test_errors(self):
             self.assertRaises(TypeError, paddle.logit, x, dtype='int32')
 
 
+class TestLogitAPICase1(TestLogitAPI):
+    def init_data(self):
+        self.x_shape = [120]
+        self.x_dtype = "float64"
+
+
+class TestLogitAPICase2(TestLogitAPI):
+    def init_data(self):
+        self.x_shape = [120]
+        self.x_dtype = "float16"
+
+
 if __name__ == "__main__":
     unittest.main()
```
