Skip to content

Commit 52119d6

Browse files
committed
refine
1 parent a1e1ae3 commit 52119d6

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

paddle/operators/dropout_op.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
4747
AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
4848
.SetDefault(.5f)
4949
.AddCustomChecker([](const float& drop_p) {
50-
PADDLE_ENFORCE(drop_p > 0.0f && drop_p < 1.0f,
51-
"'dropout_prob' must be between 0 and 1.");
50+
PADDLE_ENFORCE(drop_p >= 0.0f && drop_p <= 1.0f,
51+
"'dropout_prob' must be between 0.0 and 1.0.");
5252
});
5353
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
5454
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);

paddle/operators/dropout_op.cu

+2-3
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,15 @@ struct MaskGenerator {
3030
__host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
3131
: dropout_prob(dropout_prob), seed(seed) {}
3232

33-
__host__ __device__ T operator()(const unsigned int n) const {
33+
inline __host__ __device__ T operator()(const unsigned int n) const {
3434
thrust::minstd_rand rng;
3535
rng.seed(seed);
3636
thrust::uniform_real_distribution<AttrType> dist(0, 1);
3737
rng.discard(n);
3838
if (dist(rng) < dropout_prob) {
3939
return static_cast<T>(0);
40-
} else {
41-
return static_cast<T>(1);
4240
}
41+
return static_cast<T>(1);
4342
}
4443
};
4544

0 commit comments

Comments
 (0)