File tree 1 file changed +25
-1
lines changed
1 file changed +25
-1
lines changed Original file line number Diff line number Diff line change @@ -606,7 +606,7 @@ def test_bad_x():
606
606
data = [1 , 2 , 4 ]
607
607
result = paddle .cumsum (data , axis = 0 )
608
608
609
- with self .assertRaises (TypeError ):
609
+ with self .assertRaises (AttributeError ):
610
610
test_bad_x ()
611
611
paddle .disable_static ()
612
612
@@ -788,5 +788,29 @@ def test_static(self):
788
788
self .assertEqual (result .dtype , dtype_input )
789
789
790
790
791
+ class TestSumOpInt32 (unittest .Testcase ):
792
+ def setUp (self ):
793
+ self .shape = [2 , 3 , 4 ]
794
+ self .axis = 0
795
+ self .input_dtype = 'int32'
796
+
797
+ def test_dygraph (self ):
798
+ with dygraph_guard ():
799
+ x = paddle .ones (shape = self .shape , dtype = self .input_dtype )
800
+ result = paddle .cumsum (x , axis = self .axis )
801
+ self .assertEqual (result .dtype , paddle .int64 )
802
+
803
+ def test_static (self ):
804
+ with static_guard ():
805
+ with paddle .static .program_guard (
806
+ paddle .static .Program (), paddle .static .Program ()
807
+ ):
808
+ x = paddle .static .data (
809
+ name = 'x' , shape = self .shape , dtype = self .input_dtype
810
+ )
811
+ result = paddle .cumsum (x , axis = self .axis )
812
+ self .assertEqual (result .dtype , paddle .int64 )
813
+
814
+
791
815
if __name__ == '__main__' :
792
816
unittest .main ()
You can’t perform that action at this time.
0 commit comments