@@ -23,6 +23,7 @@

 import numpy as np
 from op_test import OpTest, convert_float_to_uint16
+from utils import dygraph_guard, static_guard

 import paddle
 import paddle.inference as paddle_infer
@@ -212,6 +213,12 @@ def set_attrs_input_output(self):
         self.out = self.x.cumsum(axis=0)


+class TestSumOpZeroSize(TestSumOp1):
+    def set_attrs_input_output(self):
+        self.x = np.random.random((1, 0, 4)).astype(self._dtype)
+        self.out = self.x.cumsum(axis=0)
+
+
 @unittest.skipIf(
     core.is_compiled_with_xpu(),
     "Skip XPU for complex dtype is not fully supported",
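The zero-size case above exercises an input with an empty dimension; NumPy's cumsum on an array with a zero-size dimension returns an empty array of the same shape, which is what the reference output relies on. A minimal sketch of that behavior, assuming only NumPy; variable names here are illustrative:

import numpy as np

# Input with a zero-size second dimension, as in TestSumOpZeroSize.
x = np.random.random((1, 0, 4)).astype('float64')

# cumsum along axis 0 keeps the (empty) shape; there is nothing to accumulate.
out = x.cumsum(axis=0)
assert out.shape == (1, 0, 4)
assert out.size == 0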
@@ -747,5 +754,40 @@ def test_fp16(self):
         paddle.disable_static()


+class TestSumOpDtypeAsPaddleDtype(unittest.TestCase):
+    def setUp(self):
+        self.shape = [2, 3, 4]
+        self.axis = 0
+        self.input_dtype = 'float32'
+        self.test_dtypes = [
+            paddle.int32,
+            paddle.int64,
+            paddle.float32,
+            paddle.float64,
+            paddle.bool,
+        ]
+
+    def test_dygraph(self):
+        with dygraph_guard():
+            x_paddle = paddle.ones(shape=self.shape, dtype=self.input_dtype)
+            for dtype_input in self.test_dtypes:
+                paddle_result = paddle.cumsum(
+                    x_paddle, axis=self.axis, dtype=dtype_input
+                )
+                self.assertEqual(paddle_result.dtype, dtype_input)
+
+    def test_static(self):
+        with static_guard():
+            for dtype_input in self.test_dtypes:
+                with paddle.static.program_guard(
+                    paddle.static.Program(), paddle.static.Program()
+                ):
+                    x = paddle.static.data(
+                        name='x', shape=self.shape, dtype=self.input_dtype
+                    )
+                    result = paddle.cumsum(x, axis=self.axis, dtype=dtype_input)
+                    self.assertEqual(result.dtype, dtype_input)
+
+
 if __name__ == '__main__':
     unittest.main()
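The dtype cases added above hinge on paddle.cumsum accepting a dtype argument; when it is given, the result tensor carries that dtype regardless of the input dtype. A minimal dygraph sketch of the behavior the new test asserts, assuming a standard Paddle install; shapes and names are illustrative:

import paddle

# float32 input, mirroring TestSumOpDtypeAsPaddleDtype.setUp.
x = paddle.ones(shape=[2, 3, 4], dtype='float32')

# Requesting an int64 result: the output dtype follows the requested dtype
# rather than the input dtype, while the shape is unchanged.
y = paddle.cumsum(x, axis=0, dtype=paddle.int64)
assert y.dtype == paddle.int64
assert tuple(y.shape) == (2, 3, 4)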