Skip to content

Commit 5308047

Browse files
committed
fix the unittest problem caused by setitem doesn't support fp16
1 parent f875964 commit 5308047

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

python/paddle/fluid/tests/unittests/test_tensor_fill_.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,11 @@ def test_tensor_fill_true(self):
4040
for dtype in typelist:
4141
var = 1.
4242
tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype)
43-
newtensor = tensor.clone()
44-
newtensor[...] = var
43+
target = tensor.numpy()
44+
target[...] = var
4545

4646
tensor.fill_(var) #var type is basic type in typelist
47-
self.assertEqual((tensor.numpy() == newtensor.numpy()).all(),
48-
True)
47+
self.assertEqual((tensor.numpy() == target).all(), True)
4948

5049
def test_tensor_fill_backward(self):
5150
typelist = ['float32']

python/paddle/fluid/tests/unittests/test_tensor_zero_.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,11 @@ def test_tensor_fill_true(self):
3535
np.array(six.moves.range(np.prod(self.shape))), self.shape)
3636
for dtype in typelist:
3737
tensor = paddle.to_tensor(np_arr, place=p, dtype=dtype)
38-
newtensor = tensor.clone()
39-
newtensor[...] = 0
38+
target = tensor.numpy()
39+
target[...] = 0
4040

4141
tensor.zero_()
42-
self.assertEqual(
43-
(tensor.numpy() == newtensor.numpy()).all().item(), True)
42+
self.assertEqual((tensor.numpy() == target).all().item(), True)
4443

4544

4645
if __name__ == '__main__':

0 commit comments

Comments
 (0)