Skip to content

Commit 3a4ae3d

Browse files
committed
add test
1 parent d625bd0 commit 3a4ae3d

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

test/legacy_test/test_cumsum_op.py

+42
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import numpy as np
2525
from op_test import OpTest, convert_float_to_uint16
26+
from utils import dygraph_guard, static_guard
2627

2728
import paddle
2829
import paddle.inference as paddle_infer
@@ -212,6 +213,12 @@ def set_attrs_input_output(self):
212213
self.out = self.x.cumsum(axis=0)
213214

214215

216+
class TestSumOpZeroSize(TestSumOp1):
217+
def set_attrs_input_output(self):
218+
self.x = np.random.random((1, 0, 4)).astype(self._dtype)
219+
self.out = self.x.cumsum(axis=0)
220+
221+
215222
@unittest.skipIf(
216223
core.is_compiled_with_xpu(),
217224
"Skip XPU for complex dtype is not fully supported",
@@ -747,5 +754,40 @@ def test_fp16(self):
747754
paddle.disable_static()
748755

749756

757+
class TestSumOpDtypeAsPaddleDtype(unittest.TestCase):
758+
def setUp(self):
759+
self.shape = [2, 3, 4]
760+
self.axis = 0
761+
self.input_dtype = 'float32'
762+
self.test_dtypes = [
763+
paddle.int32,
764+
paddle.int64,
765+
paddle.float32,
766+
paddle.float64,
767+
paddle.bool,
768+
]
769+
770+
def test_dygraph(self):
771+
with dygraph_guard():
772+
x_paddle = paddle.ones(shape=self.shape, dtype=self.input_dtype)
773+
for dtype_input in self.test_dtypes:
774+
paddle_result = paddle.cumsum(
775+
x_paddle, axis=self.axis, dtype=dtype_input
776+
)
777+
self.assertEqual(paddle_result.dtype, dtype_input)
778+
779+
def test_static(self):
780+
with static_guard():
781+
for dtype_input in self.test_dtypes:
782+
with paddle.static.program_guard(
783+
paddle.static.Program(), paddle.static.Program()
784+
):
785+
x = paddle.static.data(
786+
name='x', shape=self.shape, dtype=self.input_dtype
787+
)
788+
result = paddle.cumsum(x, axis=self.axis, dtype=dtype_input)
789+
self.assertEqual(result.dtype, dtype_input)
790+
791+
750792
if __name__ == '__main__':
751793
unittest.main()

0 commit comments

Comments
 (0)