Skip to content

Commit 5ffc22d

Browse files
authored
[Zero-Dim]Paddle.t support 0d tensor (#49880)
* support paddle.t 0d tensor * fix paddle.t test case * merge from develop
1 parent 7242f40 commit 5ffc22d

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py

100755100644
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,17 @@ def test_maseked_select(self):
13451345
self.assertEqual(x.grad.shape, [])
13461346
self.assertEqual(x.grad.numpy(), 1)
13471347

1348+
def test_t(self):
1349+
x = paddle.full([], 2.0)
1350+
x.stop_gradient = False
1351+
x.retain_grads()
1352+
out = paddle.t(x)
1353+
out.retain_grads()
1354+
out.backward()
1355+
self.assertEqual(out.shape, [])
1356+
self.assertEqual(out.grad.shape, [])
1357+
self.assertEqual(x.grad.shape, [])
1358+
13481359

13491360
class TestSundryAPIStatic(unittest.TestCase):
13501361
def setUp(self):
@@ -2080,6 +2091,21 @@ def test_maseked_select(self):
20802091
self.assertEqual(res[3].shape, ())
20812092
self.assertEqual(res[3], 1)
20822093

2094+
@prog_scope()
2095+
def test_t(self):
2096+
x = paddle.full([], 2.0)
2097+
x.stop_gradient = False
2098+
out = paddle.t(x)
2099+
paddle.static.append_backward(out.sum())
2100+
prog = paddle.static.default_main_program()
2101+
res = self.exe.run(
2102+
prog, feed={}, fetch_list=[out, out.grad_name, x.grad_name]
2103+
)
2104+
2105+
self.assertEqual(res[0].shape, ())
2106+
self.assertEqual(res[1].shape, ())
2107+
self.assertEqual(res[2].shape, ())
2108+
20832109

20842110
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
20852111
class TestNoBackwardAPI(unittest.TestCase):

python/paddle/tensor/linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,7 +1296,7 @@ def t(input, name=None):
12961296
"tensor.transpose() instead." % len(input.shape)
12971297
)
12981298
if in_dygraph_mode():
1299-
if len(input.shape) == 1:
1299+
if len(input.shape) <= 1:
13001300
return input
13011301
# 2-D tensor
13021302
perm = [1, 0]
@@ -1313,7 +1313,7 @@ def t(input, name=None):
13131313
helper = LayerHelper('t', **locals())
13141314
out = helper.create_variable_for_type_inference(input.dtype)
13151315
input_shape = helper.create_variable_for_type_inference(input.dtype)
1316-
if len(input.shape) == 1:
1316+
if len(input.shape) <= 1:
13171317
out = input
13181318
else:
13191319
helper.append_op(

0 commit comments

Comments
 (0)