Skip to content

Commit 957cbe6

Browse files
authored
fix ce error message, test=release/2.1 (#32758)
1 parent f54fb1e commit 957cbe6

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import unittest
2121
from test_softmax_op import stable_softmax
2222
from test_softmax_with_cross_entropy_op import cross_entropy
23+
from paddle.fluid import Program, program_guard
2324

2425

2526
def stable_softmax(x):
@@ -1363,5 +1364,37 @@ def test_cross_entropy_loss_2d_sum(self):
13631364
self.assertTrue(np.allclose(dy_ret_value, expected))
13641365

13651366

1367+
class TestCrossEntropyFAPIError(unittest.TestCase):
1368+
def test_errors(self):
1369+
with program_guard(Program(), Program()):
1370+
1371+
def test_LabelValue():
1372+
input_data = paddle.rand(shape=[20, 100])
1373+
label_data = paddle.randint(
1374+
0, 100, shape=[20, 1], dtype="int64")
1375+
label_data[0] = 255
1376+
weight_data = paddle.rand([100])
1377+
paddle.nn.functional.cross_entropy(
1378+
input=input_data,
1379+
label=label_data,
1380+
weight=weight_data,
1381+
ignore_index=255)
1382+
1383+
self.assertRaises(ValueError, test_LabelValue)
1384+
1385+
def test_LabelValueNeg():
1386+
input_data = paddle.rand(shape=[20, 100])
1387+
label_data = paddle.randint(
1388+
0, 100, shape=[20, 1], dtype="int64")
1389+
label_data[0] = -1
1390+
weight_data = paddle.rand([100])
1391+
paddle.nn.functional.cross_entropy(
1392+
input=input_data,
1393+
label=label_data,
1394+
weight=weight_data,
1395+
ignore_index=-1)
1396+
1397+
self.assertRaises(ValueError, test_LabelValueNeg)
1398+
13661399
if __name__ == "__main__":
13671400
unittest.main()

python/paddle/nn/functional/loss.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,6 +1411,13 @@ def cross_entropy(input,
14111411
out = core.ops.elementwise_mul(out, weight_gather_reshape)
14121412

14131413
else:
1414+
label_min = paddle.min(label)
1415+
label_max = paddle.max(label)
1416+
if label_min < 0 or label_max >= input.shape[-1]:
1417+
raise ValueError(
1418+
'Expected 0 <= label_value < class_dimension({}), but got {} <= label_value <= {} '.
1419+
format(input.shape[-1],
1420+
label_min.numpy(), label_max.numpy()))
14141421
weight_gather = core.ops.gather_nd(weight, label)
14151422
input_shape = list(label.shape)
14161423
weight_gather_reshape = reshape(

0 commit comments

Comments
 (0)