Skip to content

[PHI] fix paddle.Tensor.logit for big tensor #73046

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions paddle/phi/kernels/funcs/activation_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1213,10 +1213,17 @@ struct LogitGradFunctor {
template <typename Device, typename X, typename dOut, typename dX, typename P>
void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
// logit(x)' = 1/(x*(1-x))
dx.device(d) =
(x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
.select(p.constant(static_cast<T>(0)),
dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
if (!eps) {
dx.device(d) = (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
.select(p.constant(static_cast<T>(NAN)),
dout * (static_cast<T>(1) /
((static_cast<T>(1) - x) * x)));
} else {
dx.device(d) = (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
.select(p.constant(static_cast<T>(0)),
dout * (static_cast<T>(1) /
((static_cast<T>(1) - x) * x)));
}
}
};

Expand Down Expand Up @@ -3352,16 +3359,25 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
float eps;
MT zero = static_cast<MT>(0.0f);
MT one = static_cast<MT>(1.0f);
MT nan = static_cast<MT>(NAN);

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"eps", &eps}};
}
// logit(x)' = 1/(x*(1-x))
__device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
MT x = static_cast<MT>(arg_x);
MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
? zero
: (static_cast<MT>(dout) / (x * (one - x)));
MT dx;
if (!eps) {
dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
? nan
: (static_cast<MT>(dout) / (x * (one - x)));
} else {
dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
? zero
: (static_cast<MT>(dout) / (x * (one - x)));
}

return static_cast<T>(dx);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
Expand Down
45 changes: 34 additions & 11 deletions test/legacy_test/test_logit_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ def test_check_output(self):

def test_check_grad(self):
self.check_grad(
['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
['X'],
['Out'],
user_defined_grads=[self.x_grad],
check_pir=True,
)


Expand All @@ -78,11 +81,6 @@ def set_attrs(self):
def test_check_output(self):
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
)


class TestLogitOpFp16(TestLogitOp):
def set_attrs(self):
Expand All @@ -93,11 +91,6 @@ def set_attrs(self):
def test_check_output(self):
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(
['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
)


@unittest.skipIf(
not core.is_compiled_with_cuda()
Expand Down Expand Up @@ -202,5 +195,35 @@ def test_errors(self):
self.assertRaises(TypeError, paddle.logit, x, dtype='int32')


class TestLogitAPI_NAN_Val(unittest.TestCase):
def setUp(self):
self.init_input_output()
self.place = [paddle.CPUPlace()]
if paddle.base.core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))

def init_input_output(self):
self.x = [-0.1, 1.1, 2]
self.expect_out = [np.nan, np.nan, np.nan]
self.expect_x_grad = [np.nan, np.nan, np.nan]

def test_nan_val(self):
def _test_nan_val_with_place(place):
with paddle.base.dygraph.guard():
x = paddle.to_tensor(self.x, stop_gradient=False, place=place)
y = paddle.logit(x)
loss = y.sum()
loss.backward()
np.testing.assert_allclose(
y.numpy(), self.expect_out, rtol=1e-05
)
np.testing.assert_allclose(
x.grad.numpy(), self.expect_x_grad, rtol=1e-05
)

for place in self.place:
_test_nan_val_with_place(place)


if __name__ == "__main__":
unittest.main()
Loading