File tree 2 files changed +4
-5
lines changed
2 files changed +4
-5
lines changed Original file line number Diff line number Diff line change @@ -47,8 +47,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
47
47
AddAttr<float >(" dropout_prob" , " Probability of setting units to zero." )
48
48
.SetDefault (.5f )
49
49
.AddCustomChecker ([](const float & drop_p) {
50
- PADDLE_ENFORCE (drop_p > 0 .0f && drop_p < 1 .0f ,
51
- " 'dropout_prob' must be between 0 and 1." );
50
+ PADDLE_ENFORCE (drop_p >= 0 .0f && drop_p <= 1 .0f ,
51
+ " 'dropout_prob' must be between 0.0 and 1.0 ." );
52
52
});
53
53
AddAttr<bool >(" is_test" , " True if in test phase." ).SetDefault (false );
54
54
AddAttr<int >(" seed" , " Dropout random seed." ).SetDefault (0 );
Original file line number Diff line number Diff line change @@ -30,16 +30,15 @@ struct MaskGenerator {
30
30
__host__ __device__ MaskGenerator (AttrType dropout_prob, int seed)
31
31
: dropout_prob(dropout_prob), seed(seed) {}
32
32
33
- __host__ __device__ T operator ()(const unsigned int n) const {
33
+ inline __host__ __device__ T operator ()(const unsigned int n) const {
34
34
thrust::minstd_rand rng;
35
35
rng.seed (seed);
36
36
thrust::uniform_real_distribution<AttrType> dist (0 , 1 );
37
37
rng.discard (n);
38
38
if (dist (rng) < dropout_prob) {
39
39
return static_cast <T>(0 );
40
- } else {
41
- return static_cast <T>(1 );
42
40
}
41
+ return static_cast <T>(1 );
43
42
}
44
43
};
45
44
You can’t perform that action at this time.
0 commit comments