@@ -43,7 +43,7 @@ def test_check_grad_normal(self):
43
43
class TestDropoutOpInput1d (OpTest ):
44
44
def setUp (self ):
45
45
self .op_type = "dropout"
46
- self .inputs = {'X' : np .random .random ((2000 )).astype ("float32" )}
46
+ self .inputs = {'X' : np .random .random ((2000 , )).astype ("float32" )}
47
47
self .attrs = {'dropout_prob' : 0.0 , 'fix_seed' : True , 'is_test' : False }
48
48
self .outputs = {
49
49
'Out' : self .inputs ['X' ],
@@ -672,9 +672,11 @@ def check_static_result(self, place):
672
672
res1 = paddle .nn .functional .alpha_dropout (x = input , p = 0. )
673
673
res2 = paddle .nn .functional .alpha_dropout (
674
674
x = input , p = 0. , training = False )
675
+ res3 = paddle .nn .functional .alpha_dropout (x = input , p = 1. )
675
676
676
677
in_np = np .random .random ([40 , 40 ]).astype ("float32" )
677
678
res_np = in_np
679
+ res_np3 = np .zeros_like (in_np )
678
680
679
681
exe = fluid .Executor (place )
680
682
res_list = [res1 , res2 ]
@@ -683,6 +685,10 @@ def check_static_result(self, place):
683
685
feed = {"input" : in_np },
684
686
fetch_list = [res ])
685
687
self .assertTrue (np .allclose (fetches [0 ], res_np ))
688
+ fetches = exe .run (fluid .default_main_program (),
689
+ feed = {"input" : in_np },
690
+ fetch_list = [res3 ])
691
+ self .assertTrue (np .allclose (fetches [0 ], res_np3 ))
686
692
687
693
def test_static (self ):
688
694
for place in self .places :
@@ -693,15 +699,18 @@ def test_dygraph(self):
693
699
with fluid .dygraph .guard (place ):
694
700
in_np = np .random .random ([40 , 40 ]).astype ("float32" )
695
701
res_np = in_np
702
+ res_np3 = np .zeros_like (in_np )
696
703
input = fluid .dygraph .to_variable (in_np )
697
704
698
705
res1 = paddle .nn .functional .alpha_dropout (x = input , p = 0. )
699
706
res2 = paddle .nn .functional .alpha_dropout (
700
707
x = input , p = 0. , training = False )
708
+ res3 = paddle .nn .functional .alpha_dropout (x = input , p = 1. )
701
709
702
710
res_list = [res1 , res2 ]
703
711
for res in res_list :
704
712
self .assertTrue (np .allclose (res .numpy (), res_np ))
713
+ self .assertTrue (np .allclose (res3 .numpy (), res_np3 ))
705
714
706
715
707
716
class TestAlphaDropoutFAPIError (unittest .TestCase ):
0 commit comments