Skip to content

Commit c9113ee

Browse files
committed
fix
1 parent 27018cc commit c9113ee

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

test/legacy_test/test_cumsum_op.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ def test_bad_x():
606606
data = [1, 2, 4]
607607
result = paddle.cumsum(data, axis=0)
608608

609-
with self.assertRaises(TypeError):
609+
with self.assertRaises(AttributeError):
610610
test_bad_x()
611611
paddle.disable_static()
612612

@@ -788,5 +788,29 @@ def test_static(self):
788788
self.assertEqual(result.dtype, dtype_input)
789789

790790

791+
class TestSumOpInt32(unittest.Testcase):
792+
def setUp(self):
793+
self.shape = [2, 3, 4]
794+
self.axis = 0
795+
self.input_dtype = 'int32'
796+
797+
def test_dygraph(self):
798+
with dygraph_guard():
799+
x = paddle.ones(shape=self.shape, dtype=self.input_dtype)
800+
result = paddle.cumsum(x, axis=self.axis)
801+
self.assertEqual(result.dtype, paddle.int64)
802+
803+
def test_static(self):
804+
with static_guard():
805+
with paddle.static.program_guard(
806+
paddle.static.Program(), paddle.static.Program()
807+
):
808+
x = paddle.static.data(
809+
name='x', shape=self.shape, dtype=self.input_dtype
810+
)
811+
result = paddle.cumsum(x, axis=self.axis)
812+
self.assertEqual(result.dtype, paddle.int64)
813+
814+
791815
if __name__ == '__main__':
792816
unittest.main()

0 commit comments

Comments
 (0)