Skip to content

Commit 2ded875

Browse files
authored
fix nonzero (PaddlePaddle#72003)
* fix * fix test * fix
1 parent ff36637 commit 2ded875

File tree

4 files changed

+11
-19
lines changed

4 files changed

+11
-19
lines changed

python/paddle/tensor/search.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,6 @@ def nonzero(x: Tensor, as_tuple=False):
523523
[3]])
524524
525525
"""
526-
list_out = []
527-
shape = x.shape
528-
rank = len(shape)
529-
530526
if in_dynamic_or_pir_mode():
531527
outs = _C_ops.nonzero(x)
532528
else:
@@ -558,13 +554,9 @@ def nonzero(x: Tensor, as_tuple=False):
558554

559555
if not as_tuple:
560556
return outs
561-
elif rank == 1:
562-
return (outs,)
563557
else:
564-
for i in range(rank):
565-
list_out.append(
566-
paddle.slice(outs, axes=[1], starts=[i], ends=[i + 1])
567-
)
558+
rank = x.ndim
559+
list_out = [outs[:, i] for i in range(rank)]
568560
return tuple(list_out)
569561

570562

test/legacy_test/test_nonzero_api.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@ def test_nonzero_api_as_tuple(self):
3838
y = paddle.nonzero(x, as_tuple=True)
3939
self.assertEqual(type(y), tuple)
4040
self.assertEqual(len(y), 2)
41-
z = paddle.concat(list(y), axis=1)
41+
z = paddle.concat(list(y), axis=0)
4242
exe = base.Executor(base.CPUPlace())
4343

4444
(res,) = exe.run(
4545
feed={'x': data}, fetch_list=[z], return_numpy=False
4646
)
47-
expect_out = np.array([[0, 0], [1, 1]])
47+
expect_out = np.array([0, 1, 0, 1])
4848
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
4949

5050
data = np.array([1, 1, 0], dtype="float32")
@@ -55,12 +55,12 @@ def test_nonzero_api_as_tuple(self):
5555
y = paddle.nonzero(x, as_tuple=True)
5656
self.assertEqual(type(y), tuple)
5757
self.assertEqual(len(y), 1)
58-
z = paddle.concat(list(y), axis=1)
58+
z = paddle.concat(list(y), axis=0)
5959
exe = base.Executor(base.CPUPlace())
6060
(res,) = exe.run(
6161
feed={'x': data}, fetch_list=[z], return_numpy=False
6262
)
63-
expect_out = np.array([[0], [1]])
63+
expect_out = np.array([0, 1])
6464
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
6565

6666
def test_nonzero_api(self):

test/legacy_test/test_where_op.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -788,12 +788,12 @@ def test_where_condition(self):
788788
y = paddle.where(x)
789789
self.assertEqual(type(y), tuple)
790790
self.assertEqual(len(y), 2)
791-
z = paddle.concat(list(y), axis=1)
791+
z = paddle.concat(list(y), axis=0)
792792
exe = base.Executor(base.CPUPlace())
793793
(res,) = exe.run(
794794
feed={'x': data}, fetch_list=[z], return_numpy=False
795795
)
796-
expect_out = np.array([[0, 0], [1, 1]])
796+
expect_out = np.array([0, 1, 0, 1])
797797
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
798798
data = np.array([True, True, False])
799799
with program_guard(Program(), Program()):
@@ -803,12 +803,12 @@ def test_where_condition(self):
803803
y = paddle.where(x)
804804
self.assertEqual(type(y), tuple)
805805
self.assertEqual(len(y), 1)
806-
z = paddle.concat(list(y), axis=1)
806+
z = paddle.concat(list(y), axis=0)
807807
exe = base.Executor(base.CPUPlace())
808808
(res,) = exe.run(
809809
feed={'x': data}, fetch_list=[z], return_numpy=False
810810
)
811-
expect_out = np.array([[0], [1]])
811+
expect_out = np.array([0, 1])
812812
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
813813

814814

test/xpu/test_where_index_xpu.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_errors(self):
112112
def test_type():
113113
paddle.nonzero([10])
114114

115-
self.assertRaises(AttributeError, test_type)
115+
self.assertRaises(TypeError, test_type)
116116

117117

118118
class TestWhereSimulatorMode(unittest.TestCase):

0 commit comments

Comments
 (0)