From ebcdf6dbf365fa8a660ecd616eb60c5c70a97638 Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Wed, 28 May 2025 11:19:23 +0800
Subject: [PATCH 1/3] logit

---
 paddle/phi/kernels/funcs/activation_functor.h | 26 ++++++++---
 test/legacy_test/test_logit_op.py             | 45 +++++++++++++------
 2 files changed, 51 insertions(+), 20 deletions(-)

diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index be391396658fa3..65fc30c3633d2d 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1213,10 +1213,17 @@ struct LogitGradFunctor {
   template <typename Device, typename X, typename dOut, typename dX, typename P>
   void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
     // logit(x)' = 1/(x*(1-x))
-    dx.device(d) =
-        (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
-            .select(p.constant(static_cast<T>(0)),
-                    dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
+    if (!eps) {
+      dx.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
+                         .select(p.constant(static_cast<T>(NAN)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    } else {
+      dx.device(d) = (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
+                         .select(p.constant(static_cast<T>(0)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    }
   }
 };
 
@@ -3359,9 +3366,14 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
   // logit(x)' = 1/(x*(1-x))
   __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
     MT x = static_cast<MT>(arg_x);
-    MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
-                ? zero
-                : (static_cast<MT>(dout) / (x * (one - x)));
+    if (!eps) {
+      MT dx = (x < zero || x > one) ? static_cast<MT>(NAN)
+                                    : (static_cast<MT>(dout) / (x * (one - x)));
+    } else {
+      MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
+                  ? zero
+                  : (static_cast<MT>(dout) / (x * (one - x)));
+    }
     return static_cast<T>(dx);
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() {
diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py
index 5ab56c831c8c20..503ee4a29311a8 100644
--- a/test/legacy_test/test_logit_op.py
+++ b/test/legacy_test/test_logit_op.py
@@ -24,15 +24,29 @@
 
 
 def logit(x, eps):
-    x_min = np.minimum(x, 1.0 - eps)
-    x_max = np.maximum(x_min, eps)
-    return np.log(x_max / (1.0 - x_max))
+    if eps:
+        x_min = np.minimum(x, 1.0 - eps)
+        x_max = np.maximum(x_min, eps)
+        return np.log(x_max / (1.0 - x_max))
+    else:
+        return np.where(
+            (x < 0.0) | (x > 1.0),
+            np.array(np.nan, dtype=x.dtype),
+            np.log(x / (1.0 - x)),
+        )
 
 
 def logit_grad(x, eps=1e-8):
-    tmp_x = np.select(
-        [x < eps, x > (1.0 - eps)], [x * 0.0, x * 0.0], default=-1.0
-    )
+    if eps:
+        tmp_x = np.select(
+            [x < eps, x > (1.0 - eps)], [x * 0.0, x * 0.0], default=-1.0
+        )
+    else:
+        tmp_x = np.select(
+            [x < 0.0, x > 1.0],
+            [np.array(np.nan, dtype=x.dtype), np.array(np.nan, dtype=x.dtype)],
+            default=-1.0,
+        )
     x_1 = 1.0 - x
     _x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
     dout = np.full_like(x, fill_value=1.0 / _x.size)
@@ -58,11 +72,16 @@ def set_attrs(self):
         self.eps = 1e-8
 
     def test_check_output(self):
-        self.check_output(check_pir=True, check_symbol_infer=False)
+        self.check_output(
+            check_pir=True, check_symbol_infer=False, equal_nan=True
+        )
 
     def test_check_grad(self):
         self.check_grad(
-            ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
+            ['X'],
+            ['Out'],
+            user_defined_grads=[self.x_grad],
+            check_pir=True,
         )
 
 
@@ -73,7 +92,7 @@ def set_attrs(self):
         self.eps = 1e-8
 
     def test_check_output(self):
-        self.check_output(check_pir=True)
+        self.check_output(check_pir=True, equal_nan=True)
 
     def test_check_grad(self):
         self.check_grad(
@@ -88,7 +107,7 @@ def set_attrs(self):
         self.eps = 1e-8
 
     def test_check_output(self):
-        self.check_output(check_pir=True)
+        self.check_output(check_pir=True, equal_nan=True)
 
     def test_check_grad(self):
         self.check_grad(
@@ -122,7 +141,7 @@ def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_output_with_place(
-                place, check_pir=True, check_symbol_infer=False
+                place, check_pir=True, check_symbol_infer=False, equal_nan=True
             )
 
     def test_check_grad(self):
@@ -141,7 +160,7 @@ class TestLogitShape(TestLogitOp):
     def set_attrs(self):
         self.dtype = np.float64
         self.shape = [2, 60]
-        self.eps = 1e-8
+        self.eps = 0.0
 
 
 class TestLogitEps(TestLogitOp):
@@ -173,7 +192,7 @@ def check_api(self, eps=1e-8):
         # test dygrapg api
         paddle.disable_static()
         x = paddle.to_tensor(self.x)
-        y = paddle.logit(x, 1e-8)
+        y = paddle.logit(x, eps)
         np.testing.assert_allclose(y.numpy(), ref_out, rtol=1e-05)
         paddle.enable_static()
 

From 01b5509fde2a96fb64b3191f1824067c4ce0f042 Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Wed, 28 May 2025 12:46:13 +0800
Subject: [PATCH 2/3] fix

---
 test/legacy_test/test_logit_op.py | 80 ++++++++++++++++++++++++-------
 1 file changed, 62 insertions(+), 18 deletions(-)

diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py
index 503ee4a29311a8..93c6f33c6211cb 100644
--- a/test/legacy_test/test_logit_op.py
+++ b/test/legacy_test/test_logit_op.py
@@ -41,15 +41,25 @@ def logit_grad(x, eps=1e-8):
         tmp_x = np.select(
             [x < eps, x > (1.0 - eps)], [x * 0.0, x * 0.0], default=-1.0
         )
+        x_1 = 1.0 - x
+        _x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
     else:
         tmp_x = np.select(
             [x < 0.0, x > 1.0],
             [np.array(np.nan, dtype=x.dtype), np.array(np.nan, dtype=x.dtype)],
             default=-1.0,
         )
-    x_1 = 1.0 - x
-    _x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
-    dout = np.full_like(x, fill_value=1.0 / _x.size)
+        x_1 = 1.0 - x
+        _x = np.select(
+            [tmp_x == -1.0],
+            [np.reciprocal(x * x_1)],
+            default=np.array(np.nan, dtype=x.dtype),
+        )
+
+    if _x.size == 0:
+        dout = np.full_like(x, fill_value=0.0)
+    else:
+        dout = np.full_like(x, fill_value=1.0 / _x.size)
     dx = dout * _x
     return dx
 
@@ -72,16 +82,11 @@ def set_attrs(self):
         self.eps = 1e-8
 
     def test_check_output(self):
-        self.check_output(
-            check_pir=True, check_symbol_infer=False, equal_nan=True
-        )
+        self.check_output(check_pir=True, check_symbol_infer=False)
 
     def test_check_grad(self):
         self.check_grad(
-            ['X'],
-            ['Out'],
-            user_defined_grads=[self.x_grad],
-            check_pir=True,
+            ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
         )
 
 
@@ -92,7 +97,7 @@ def set_attrs(self):
         self.eps = 1e-8
 
     def test_check_output(self):
-        self.check_output(check_pir=True, equal_nan=True)
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         self.check_grad(
@@ -107,7 +112,7 @@ def set_attrs(self):
         self.eps = 1e-8
 
     def test_check_output(self):
-        self.check_output(check_pir=True, equal_nan=True)
+        self.check_output(check_pir=True)
 
     def test_check_grad(self):
         self.check_grad(
@@ -141,7 +146,7 @@ def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
             self.check_output_with_place(
-                place, check_pir=True, check_symbol_infer=False, equal_nan=True
+                place, check_pir=True, check_symbol_infer=False
             )
 
     def test_check_grad(self):
@@ -160,7 +165,7 @@ class TestLogitShape(TestLogitOp):
     def set_attrs(self):
         self.dtype = np.float64
         self.shape = [2, 60]
-        self.eps = 0.0
+        self.eps = 1e-8
 
 
 class TestLogitEps(TestLogitOp):
@@ -170,10 +175,21 @@ def set_attrs(self):
         self.eps = 1e-8
 
 
+class TestLogit_ZeroSize(TestLogitOp):
+    def set_attrs(self):
+        self.dtype = np.float64
+        self.shape = [2, 0]
+        self.eps = 1e-8
+
+
 class TestLogitAPI(unittest.TestCase):
-    def setUp(self):
+    def init_data(self):
         self.x_shape = [120]
-        self.x = np.random.uniform(0.0, 1.0, self.x_shape).astype(np.float32)
+        self.x_dtype = "float32"
+
+    def setUp(self):
+        self.init_data()
+        self.x = np.random.uniform(-1.0, 1.0, self.x_shape).astype(self.x_dtype)
         self.place = (
             paddle.CUDAPlace(0)
             if paddle.base.core.is_compiled_with_cuda()
@@ -184,22 +200,38 @@ def check_api(self, eps=1e-8):
         ref_out = logit(self.x, eps)
         # test static api
         with paddle.static.program_guard(paddle.static.Program()):
-            x = paddle.static.data(name='x', shape=self.x_shape)
+            x = paddle.static.data(
+                name='x', shape=self.x_shape, dtype=self.x_dtype
+            )
             y = paddle.logit(x, eps)
             exe = paddle.static.Executor(self.place)
             out = exe.run(feed={'x': self.x}, fetch_list=[y])
             np.testing.assert_allclose(out[0], ref_out, rtol=1e-05)
         # test dygrapg api
         paddle.disable_static()
-        x = paddle.to_tensor(self.x)
+        x = paddle.to_tensor(self.x, dtype=self.x_dtype)
         y = paddle.logit(x, eps)
         np.testing.assert_allclose(y.numpy(), ref_out, rtol=1e-05)
         paddle.enable_static()
 
+    def check_api_grad(self, eps=1e-8):
+        ref_grad = logit_grad(self.x, eps)
+        numpy_tensor = np.ones(self.x_shape).astype(self.x_dtype)
+        # test dygrapg api
+        paddle.disable_static()
+        paddle_outgrad = paddle.to_tensor(numpy_tensor / numpy_tensor.size)
+        x = paddle.to_tensor(self.x, dtype=self.x_dtype)
+        x.stop_gradient = False
+        y = paddle.logit(x, eps)
+        x_grad = paddle.grad([y], [x], [paddle_outgrad])
+        np.testing.assert_allclose(x_grad[0].numpy(), ref_grad, rtol=1e-05)
+        paddle.enable_static()
+
     def test_check_api(self):
         paddle.enable_static()
         for eps in [1e-6, 0.0]:
             self.check_api(eps)
+            self.check_api_grad(eps)
 
     def test_errors(self):
         paddle.enable_static()
@@ -211,5 +243,17 @@ def test_errors(self):
             self.assertRaises(TypeError, paddle.logit, x, dtype='int32')
 
 
+class TestLogitAPICase1(unittest.TestCase):
+    def init_data(self):
+        self.x_shape = [120]
+        self.x_dtype = "float64"
+
+
+class TestLogitAPICase2(unittest.TestCase):
+    def init_data(self):
+        self.x_shape = [120]
+        self.x_dtype = "float16"
+
+
 if __name__ == "__main__":
     unittest.main()

From 6b3dfc434702cb45213983b4e04117a01652f3d5 Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Wed, 28 May 2025 14:14:41 +0800
Subject: [PATCH 3/3] fix

---
 paddle/phi/kernels/funcs/activation_functor.h | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 65fc30c3633d2d..504818dd3f9244 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -3366,13 +3366,14 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
   // logit(x)' = 1/(x*(1-x))
   __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
     MT x = static_cast<MT>(arg_x);
+    MT dx;
     if (!eps) {
-      MT dx = (x < zero || x > one) ? static_cast<MT>(NAN)
-                                    : (static_cast<MT>(dout) / (x * (one - x)));
+      dx = (x < zero || x > one) ? static_cast<MT>(NAN)
+                                 : (static_cast<MT>(dout) / (x * (one - x)));
     } else {
-      MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
-                  ? zero
-                  : (static_cast<MT>(dout) / (x * (one - x)));
+      dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
+               ? zero
+               : (static_cast<MT>(dout) / (x * (one - x)));
     }
     return static_cast<T>(dx);
   }
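
The behaviour the three patches converge on can be summarized outside the diffs: with eps > 0, the gradient functors (LogitGradFunctor and CudaLogitGradFunctor) and the NumPy reference helpers in test_logit_op.py keep the existing semantics, clamping into [eps, 1 - eps] in the forward pass and returning a zero gradient in the clamped region; with eps == 0.0, inputs outside [0, 1] propagate NaN instead. The short NumPy sketch below restates that semantics for illustration only: the names logit_ref and logit_grad_ref are not part of the patches, and the gradient shown is the raw elementwise derivative, without the 1 / size output gradient that the op test multiplies in.

import numpy as np


def logit_ref(x, eps):
    # eps > 0: clamp x into [eps, 1 - eps], then take the log-odds.
    if eps:
        x = np.maximum(np.minimum(x, 1.0 - eps), eps)
        return np.log(x / (1.0 - x))
    # eps == 0: out-of-range inputs map to NaN instead of being clamped.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(
            (x < 0.0) | (x > 1.0),
            np.array(np.nan, dtype=x.dtype),
            np.log(x / (1.0 - x)),
        )


def logit_grad_ref(x, eps):
    # d/dx logit(x) = 1 / (x * (1 - x)) inside the valid range.
    with np.errstate(divide="ignore", invalid="ignore"):
        grad = np.reciprocal(x * (1.0 - x))
    if eps:
        # eps > 0: the gradient is zeroed wherever the forward pass clamps.
        return np.where((x < eps) | (x > 1.0 - eps), 0.0, grad)
    # eps == 0: the gradient is NaN outside [0, 1].
    return np.where((x < 0.0) | (x > 1.0), np.nan, grad)


x = np.array([-0.5, 0.25, 0.75, 1.5])
print(logit_ref(x, 0.0))       # [nan -1.0986...  1.0986...  nan]
print(logit_grad_ref(x, 0.0))  # [nan  5.3333...  5.3333...  nan]
print(logit_ref(x, 1e-8))      # endpoints clamped, so every entry is finite

With eps > 0 nothing changes relative to the old behaviour; after the second patch reverts the eps = 0.0 op-test experiment, the NaN path is exercised by the eps = 0.0 case that TestLogitAPI.test_check_api loops over.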