
Commit b40613b

fix alpha dropout bug when p=1, test=develop
1 parent: bcdbac1
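Background: alpha dropout (Klambauer et al., "Self-Normalizing Neural Networks") sets dropped units to alpha_p = -alpha * scale instead of 0, then applies an affine transform a*x + b chosen so the activations keep their mean and variance. The coefficient a = ((1 - p) * (1 + p * alpha_p**2)) ** -0.5 is undefined at dropout probability p = 1, so that case needs a short circuit: with every unit dropped, the output should simply be an all-zero tensor. The diff below adds that guard to paddle.nn.functional.alpha_dropout and tests it in both static-graph and dygraph mode.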

File tree

2 files changed, 12 insertions(+), 1 deletion(-)


python/paddle/fluid/tests/unittests/test_dropout_op.py

Lines changed: 10 additions & 1 deletion
@@ -43,7 +43,7 @@ def test_check_grad_normal(self):
 class TestDropoutOpInput1d(OpTest):
     def setUp(self):
         self.op_type = "dropout"
-        self.inputs = {'X': np.random.random((2000)).astype("float32")}
+        self.inputs = {'X': np.random.random((2000, )).astype("float32")}
         self.attrs = {'dropout_prob': 0.0, 'fix_seed': True, 'is_test': False}
         self.outputs = {
             'Out': self.inputs['X'],
@@ -672,9 +672,11 @@ def check_static_result(self, place):
             res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.)
             res2 = paddle.nn.functional.alpha_dropout(
                 x=input, p=0., training=False)
+            res3 = paddle.nn.functional.alpha_dropout(x=input, p=1.)

             in_np = np.random.random([40, 40]).astype("float32")
             res_np = in_np
+            res_np3 = np.zeros_like(in_np)

             exe = fluid.Executor(place)
             res_list = [res1, res2]
@@ -683,6 +685,10 @@ def check_static_result(self, place):
                                   feed={"input": in_np},
                                   fetch_list=[res])
                 self.assertTrue(np.allclose(fetches[0], res_np))
+            fetches = exe.run(fluid.default_main_program(),
+                              feed={"input": in_np},
+                              fetch_list=[res3])
+            self.assertTrue(np.allclose(fetches[0], res_np3))

     def test_static(self):
         for place in self.places:
@@ -693,15 +699,18 @@ def test_dygraph(self):
             with fluid.dygraph.guard(place):
                 in_np = np.random.random([40, 40]).astype("float32")
                 res_np = in_np
+                res_np3 = np.zeros_like(in_np)
                 input = fluid.dygraph.to_variable(in_np)

                 res1 = paddle.nn.functional.alpha_dropout(x=input, p=0.)
                 res2 = paddle.nn.functional.alpha_dropout(
                     x=input, p=0., training=False)
+                res3 = paddle.nn.functional.alpha_dropout(x=input, p=1.)

                 res_list = [res1, res2]
                 for res in res_list:
                     self.assertTrue(np.allclose(res.numpy(), res_np))
+                self.assertTrue(np.allclose(res3.numpy(), res_np3))


 class TestAlphaDropoutFAPIError(unittest.TestCase):
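Taken together, the new assertions say: with p=1. every unit is dropped and the op returns an all-zero tensor with the input's shape and dtype, in both static and dygraph mode. A minimal standalone dygraph check, mirroring the test above (the 40x40 shape and the APIs are taken from the diff):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        in_np = np.random.random([40, 40]).astype("float32")
        x = fluid.dygraph.to_variable(in_np)
        # p=1. drops every unit; after this fix the op returns zeros
        # instead of failing.
        out = paddle.nn.functional.alpha_dropout(x=x, p=1.)
        assert np.allclose(out.numpy(), np.zeros_like(in_np))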

python/paddle/nn/functional/common.py

Lines changed: 2 additions & 0 deletions
@@ -1091,6 +1091,8 @@ def alpha_dropout(x, p=0.5, training=True, name=None):
                                  'alpha_dropout')

     if training:
+        if p == 1:
+            return layers.scale(x, scale=0.)
         #get transformation params
         alpha = 1.6732632423543772848170429916717
         scale = 1.0507009873554804934193349852946
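Why the guard is needed: just below the patched lines, the general path derives the variance-preserving coefficients from p, and that computation cannot be evaluated at p = 1. A minimal sketch of the failure mode (only the two constants are visible in the diff; the formula and variable names follow the SELU paper and are assumptions about the surrounding code):

    # SELU constants, as in the diff above.
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    p = 1.0
    alpha_p = -alpha * scale  # value assigned to a dropped unit
    # a restores mean/variance after masking; its base (1 - p) * (...)
    # is exactly 0.0 at p = 1, and a negative power of 0.0 fails.
    try:
        a = ((1 - p) * (1 + p * alpha_p ** 2)) ** -0.5
    except ZeroDivisionError as err:
        print(err)  # "0.0 cannot be raised to a negative power"

Returning layers.scale(x, scale=0.) rather than building a fresh zeros tensor keeps the output's shape, dtype, and placement tied to x, stays inside the program graph in static mode, and gives a well-defined (zero) gradient with respect to x.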
