diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index be391396658fa3..504818dd3f9244 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1213,10 +1213,17 @@ struct LogitGradFunctor {
   template <typename Device, typename X, typename dOut, typename dX, typename P>
   void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
     // logit(x)' = 1/(x*(1-x))
-    dx.device(d) =
-        (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
-            .select(p.constant(static_cast<T>(0)),
-                    dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
+    if (!eps) {
+      dx.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
+                         .select(p.constant(static_cast<T>(NAN)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    } else {
+      dx.device(d) = (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
+                         .select(p.constant(static_cast<T>(0)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    }
   }
 };
 
@@ -3359,9 +3366,15 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
   // logit(x)' = 1/(x*(1-x))
   __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
     MT x = static_cast<MT>(arg_x);
-    MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
-                ? zero
-                : (static_cast<MT>(dout) / (x * (one - x)));
+    MT dx;
+    if (!eps) {
+      dx = (x < zero || x > one) ? static_cast<MT>(NAN)
+                                 : (static_cast<MT>(dout) / (x * (one - x)));
+    } else {
+      dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
+               ? zero
+               : (static_cast<MT>(dout) / (x * (one - x)));
+    }
     return static_cast<T>(dx);
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() {
diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py
index b4c04c4b63e263..93c6f33c6211cb 100644
--- a/test/legacy_test/test_logit_op.py
+++ b/test/legacy_test/test_logit_op.py
@@ -24,17 +24,38 @@
 
 
 def logit(x, eps):
-    x_min = np.minimum(x, 1.0 - eps)
-    x_max = np.maximum(x_min, eps)
-    return np.log(x_max / (1.0 - x_max))
+    if eps:
+        x_min = np.minimum(x, 1.0 - eps)
+        x_max = np.maximum(x_min, eps)
+        return np.log(x_max / (1.0 - x_max))
+    else:
+        return np.where(
+            (x < 0.0) | (x > 1.0),
+            np.array(np.nan, dtype=x.dtype),
+            np.log(x / (1.0 - x)),
+        )
 
 
 def logit_grad(x, eps=1e-8):
-    tmp_x = np.select(
-        [x < eps, x > (1.0 - eps)], [x * 0.0, x * 0.0], default=-1.0
-    )
-    x_1 = 1.0 - x
-    _x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
+    if eps:
+        tmp_x = np.select(
+            [x < eps, x > (1.0 - eps)], [x * 0.0, x * 0.0], default=-1.0
+        )
+        x_1 = 1.0 - x
+        _x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
+    else:
+        tmp_x = np.select(
+            [x < 0.0, x > 1.0],
+            [np.array(np.nan, dtype=x.dtype), np.array(np.nan, dtype=x.dtype)],
+            default=-1.0,
+        )
+        x_1 = 1.0 - x
+        _x = np.select(
+            [tmp_x == -1.0],
+            [np.reciprocal(x * x_1)],
+            default=np.array(np.nan, dtype=x.dtype),
+        )
+
     if _x.size == 0:
         dout = np.full_like(x, fill_value=0.0)
     else:
@@ -162,9 +183,13 @@ def set_attrs(self):
 
 
 class TestLogitAPI(unittest.TestCase):
-    def setUp(self):
+    def init_data(self):
         self.x_shape = [120]
-        self.x = np.random.uniform(0.0, 1.0, self.x_shape).astype(np.float32)
+        self.x_dtype = "float32"
+
+    def setUp(self):
+        self.init_data()
+        self.x = np.random.uniform(-1.0, 1.0, self.x_shape).astype(self.x_dtype)
         self.place = (
             paddle.CUDAPlace(0)
             if paddle.base.core.is_compiled_with_cuda()
@@ -175,22 +200,38 @@ def check_api(self, eps=1e-8):
         ref_out = logit(self.x, eps)
         # test static api
         with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data(name='x', shape=self.x_shape)
+            x = paddle.static.data(
+                name='x', shape=self.x_shape, dtype=self.x_dtype
+            )
             y = paddle.logit(x, eps)
             exe = paddle.static.Executor(self.place)
             out = exe.run(feed={'x': self.x}, fetch_list=[y])
         np.testing.assert_allclose(out[0], ref_out, rtol=1e-05)
         # test dygrapg api
         paddle.disable_static()
-        x = paddle.to_tensor(self.x)
-        y = paddle.logit(x, 1e-8)
+        x = paddle.to_tensor(self.x, dtype=self.x_dtype)
+        y = paddle.logit(x, eps)
         np.testing.assert_allclose(y.numpy(), ref_out, rtol=1e-05)
         paddle.enable_static()
 
+    def check_api_grad(self, eps=1e-8):
+        ref_grad = logit_grad(self.x, eps)
+        numpy_tensor = np.ones(self.x_shape).astype(self.x_dtype)
+        # test dygrapg api
+        paddle.disable_static()
+        paddle_outgrad = paddle.to_tensor(numpy_tensor / numpy_tensor.size)
+        x = paddle.to_tensor(self.x, dtype=self.x_dtype)
+        x.stop_gradient = False
+        y = paddle.logit(x, eps)
+        x_grad = paddle.grad([y], [x], [paddle_outgrad])
+        np.testing.assert_allclose(x_grad[0].numpy(), ref_grad, rtol=1e-05)
+        paddle.enable_static()
+
     def test_check_api(self):
         paddle.enable_static()
         for eps in [1e-6, 0.0]:
             self.check_api(eps)
+            self.check_api_grad(eps)
 
     def test_errors(self):
         paddle.enable_static()
@@ -202,5 +243,17 @@
             self.assertRaises(TypeError, paddle.logit, x, dtype='int32')
 
 
+class TestLogitAPICase1(unittest.TestCase):
+    def init_data(self):
+        self.x_shape = [120]
+        self.x_dtype = "float64"
+
+
+class TestLogitAPICase2(unittest.TestCase):
+    def init_data(self):
+        self.x_shape = [120]
+        self.x_dtype = "float16"
+
+
 if __name__ == "__main__":
     unittest.main()
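
For reference (not part of the patch), below is a minimal NumPy sketch of the semantics the change encodes: with eps == 0 the logit gradient propagates NaN for inputs outside [0, 1], while any positive eps clamps the input range and zeroes the gradient there instead. The helper names ref_logit and ref_logit_grad are hypothetical, not Paddle APIs, and the gradient is shown with an incoming gradient dout of 1.

    # Illustrative sketch only; ref_logit / ref_logit_grad are hypothetical
    # helpers mirroring the eps == 0 vs. eps > 0 behaviour in the patch above.
    import numpy as np

    def ref_logit(x, eps):
        with np.errstate(divide="ignore", invalid="ignore"):
            if eps:
                x = np.clip(x, eps, 1.0 - eps)  # clamp: never produces NaN
                return np.log(x / (1.0 - x))
            # eps == 0: out-of-domain inputs become NaN
            return np.where((x < 0.0) | (x > 1.0), np.nan, np.log(x / (1.0 - x)))

    def ref_logit_grad(x, eps):
        with np.errstate(divide="ignore", invalid="ignore"):
            grad = 1.0 / (x * (1.0 - x))  # d/dx logit(x), with dout == 1
            if eps:
                return np.where((x < eps) | (x > 1.0 - eps), 0.0, grad)
            return np.where((x < 0.0) | (x > 1.0), np.nan, grad)

    x = np.array([-0.5, 0.25, 0.75, 1.5])
    print(ref_logit_grad(x, 0.0))   # approx. [nan, 5.333, 5.333, nan]
    print(ref_logit_grad(x, 1e-6))  # approx. [0.0, 5.333, 5.333, 0.0]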