Skip to content

Commit 9d2c9e4

Browse files
authored
[PHI] fix paddle.Tensor.logit for big tensor (#73046)

* fix logit grad
* fix kernel bug and add unittest
* rm rhol
* fix codestyle error
1 parent 0d4566d commit 9d2c9e4

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3375,6 +3375,7 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
33753375
? zero
33763376
: (static_cast<MT>(dout) / (x * (one - x)));
33773377
}
3378+
33783379
return static_cast<T>(dx);
33793380
}
33803381
static constexpr ActBwdOpFwdDeps FwdDeps() {

test/legacy_test/test_logit_op.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ def test_check_output(self):
8686

8787
def test_check_grad(self):
8888
self.check_grad(
89-
['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
89+
['X'],
90+
['Out'],
91+
user_defined_grads=[self.x_grad],
92+
check_pir=True,
9093
)
9194

9295

@@ -99,11 +102,6 @@ def set_attrs(self):
99102
def test_check_output(self):
100103
self.check_output(check_pir=True)
101104

102-
def test_check_grad(self):
103-
self.check_grad(
104-
['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
105-
)
106-
107105

108106
class TestLogitOpFp16(TestLogitOp):
109107
def set_attrs(self):
@@ -114,11 +112,6 @@ def set_attrs(self):
114112
def test_check_output(self):
115113
self.check_output(check_pir=True)
116114

117-
def test_check_grad(self):
118-
self.check_grad(
119-
['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
120-
)
121-
122115

123116
@unittest.skipIf(
124117
not core.is_compiled_with_cuda()
@@ -243,6 +236,36 @@ def test_errors(self):
243236
self.assertRaises(TypeError, paddle.logit, x, dtype='int32')
244237

245238

239+
class TestLogitAPI_NAN_Val(unittest.TestCase):
240+
def setUp(self):
241+
self.init_input_output()
242+
self.place = [paddle.CPUPlace()]
243+
if paddle.base.core.is_compiled_with_cuda():
244+
self.place.append(paddle.CUDAPlace(0))
245+
246+
def init_input_output(self):
247+
self.x = [-0.1, 1.1, 2]
248+
self.expect_out = [np.nan, np.nan, np.nan]
249+
self.expect_x_grad = [np.nan, np.nan, np.nan]
250+
251+
def test_nan_val(self):
252+
def _test_nan_val_with_place(place):
253+
with paddle.base.dygraph.guard():
254+
x = paddle.to_tensor(self.x, stop_gradient=False, place=place)
255+
y = paddle.logit(x)
256+
loss = y.sum()
257+
loss.backward()
258+
np.testing.assert_allclose(
259+
y.numpy(), self.expect_out, rtol=1e-05
260+
)
261+
np.testing.assert_allclose(
262+
x.grad.numpy(), self.expect_x_grad, rtol=1e-05
263+
)
264+
265+
for place in self.place:
266+
_test_nan_val_with_place(place)
267+
268+
246269
class TestLogitAPICase1(unittest.TestCase):
247270
def init_data(self):
248271
self.x_shape = [120]

0 commit comments

Comments (0)