
Commit 05d7e2f

Add fused_dropout wrapper to ease use. (#36185) (#36640)
In the fused_attention op and fused_ffn op, a fused bias_add + dropout + residual + layernorm kernel or a fused bias_add + dropout + residual kernel is used. To ease the use of these kernels, this PR provides a wrapper.
1. To reuse the increment-computing code, the corresponding code is extracted into a GetSeedDataAndIncrement routine in dropout_impl_util.h.
2. fused_dropout_helper.h provides the fused dropout kernel wrapper.
Note: tests of this wrapper will be provided in the following fused_attention_op and fused_ffn PRs.
1 parent 1906c74 commit 05d7e2f

File tree

3 files changed: 338 additions & 22 deletions


paddle/fluid/operators/dropout_impl.cu.h

Lines changed: 3 additions & 22 deletions
@@ -30,6 +30,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/dropout_impl_util.h"
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/gpu_launch_config.h"
@@ -196,28 +197,8 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                               config.thread_per_block.x * vec_size) +
                   1) *
                  vec_size;
-    int device_id =
-        BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
-    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
-
-    if ((seed) && platform::is_gpu_place(seed->place())) {
-      framework::Tensor seed_cpu_tensor;
-      TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
-      seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
-      increment = offset;
-    } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
-      auto seed_offset = gen_cuda->IncrementOffset(offset);
-      seed_data = seed_offset.first;
-      increment = seed_offset.second;
-    } else {
-      if (seed) {
-        seed_data = *(seed->data<int>());
-      } else {
-        std::random_device rnd;
-        seed_data = is_fix_seed ? seed_val : rnd();
-      }
-      increment = offset;
-    }
+    GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
+                            &seed_data, &increment);
 
 #ifdef __HIPCC__
     if (vec_size == 4 && size % 4 == 0) {
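
For intuition, the offset passed to GetSeedDataAndIncrement is the number of random values each GPU thread may consume, rounded up to the vectorization width. A small worked example with made-up launch numbers (not taken from this diff):

// Illustrative only: 10 blocks of 256 threads, vec_size = 4, size = 10000.
int blocks = 10, threads = 256, vec_size = 4, size = 10000;
int offset = ((size - 1) / (blocks * threads * vec_size) + 1) * vec_size;
// offset == 4: each thread needs at most 4 random numbers, so the default
// CUDA generator's Philox offset is advanced by 4 for this launch.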
paddle/fluid/operators/dropout_impl_util.h

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace paddle {
namespace operators {

inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
                                    const framework::Tensor* seed,
                                    const bool is_fix_seed, const int seed_val,
                                    const int offset, uint64_t* seed_data,
                                    uint64_t* increment) {
  int device_id =
      BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
  auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

  if ((seed) && platform::is_gpu_place(seed->place())) {
    framework::Tensor seed_cpu_tensor;
    TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
    *seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
    *increment = offset;
  } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
    auto seed_offset = gen_cuda->IncrementOffset(offset);
    *seed_data = seed_offset.first;
    *increment = seed_offset.second;
  } else {
    if (seed) {
      *seed_data = *(seed->data<int>());
    } else {
      std::random_device rnd;
      *seed_data = is_fix_seed ? seed_val : rnd();
    }
    *increment = offset;
  }
}

}  // namespace operators
}  // namespace paddle
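
A minimal sketch of a call site, mirroring the updated DropoutFwGPUKernelDriver above (dev_ctx, seed, is_fix_seed, seed_val and offset are assumed to already be available in the caller):

uint64_t seed_data = 0;
uint64_t increment = 0;
// Resolution order inside the helper: (1) a seed tensor living on the GPU,
// (2) the default CUDA generator when it was initialized from Python and the
// seed is not fixed, (3) the fixed seed attribute or std::random_device.
GetSeedDataAndIncrement(dev_ctx, seed, is_fix_seed, seed_val, offset,
                        &seed_data, &increment);
// seed_data and increment are then forwarded to the Philox-based dropout
// kernel setup (curand/hiprand) inside the launched kernel.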
paddle/fluid/operators/fused/fused_dropout_helper.h

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/fluid/operators/math/functors.h"

namespace paddle {
namespace operators {

/**
 * Supports two dropouts in the usage scenario.
 * This wrapper can be used in the FFN op.
 * The DropoutParam will be used in fused_dropout_act_bias,
 * fused_residual_dropout_bias (pre_layer_norm=true) or
 * fused_layernorm_residual_dropout_bias (pre_layer_norm=false).
 */
struct DropoutParam {
  uint64_t seed;
  float dropout_prob;
  bool is_upscale_in_train;
  bool is_test;
  bool fix_seed;
  int increment;
  const framework::Tensor* tensor_seed;
  int seed_val;

  DropoutParam() {
    fix_seed = false;
    seed = 0;
    is_test = false;
    is_upscale_in_train = false;
    dropout_prob = 0.5;
    tensor_seed = nullptr;
    seed_val = 0;
  }

  /**
   * dropout_index: can be 0, 1 or 2. 0 means there is only one dropout;
   * 1 and 2 represent two dropouts. The attribute name of a dropout
   * is "dropout" + dropout_index + "_" + param name, such as dropout1_seed,
   * dropout1_is_test.
   */
  DropoutParam(const framework::ExecutionContext& context,
               const int dropout_index) {
    std::string pre_fix = "dropout";
    std::string str_index = std::to_string(dropout_index);
    if (dropout_index > 0) {
      pre_fix = pre_fix + str_index + "_";
    } else {
      pre_fix = pre_fix + "_";
    }
    dropout_prob = context.Attr<float>(pre_fix + "prob");
    auto& dropout_implementation =
        context.Attr<std::string>(pre_fix + "implementation");
    is_upscale_in_train = (dropout_implementation == "upscale_in_train");
    is_test = context.Attr<bool>(pre_fix + "is_test");
    fix_seed = context.Attr<bool>(pre_fix + "fix_seed");

    std::string str_seed = "Dropout";
    if (dropout_index > 0) {
      str_seed = str_seed + str_index + "Seed";
    } else {
      str_seed = str_seed + "Seed";
    }
    tensor_seed =
        context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
    seed_val = context.Attr<int>(pre_fix + "seed");
  }

  int UpdateSeedAndIncrement(const platform::CUDADeviceContext& ctx,
                             const int offset) {
    uint64_t tmp_increment;
    GetSeedDataAndIncrement(ctx, tensor_seed, fix_seed, seed_val, offset, &seed,
                            &tmp_increment);
    increment = static_cast<int>(tmp_increment);
    return increment;
  }
};

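To make the attribute-naming rule above concrete, here is a hypothetical construction for an op with two dropouts (the op definition itself is not part of this commit; `context` is assumed to be the op's ExecutionContext):

// dropout_index == 1 reads the attributes dropout1_prob,
// dropout1_implementation, dropout1_is_test, dropout1_fix_seed, dropout1_seed
// and the optional input tensor Dropout1Seed:
DropoutParam dropout_param1(context, 1);
// dropout_index == 0 reads dropout_prob, dropout_implementation, ... and the
// optional input tensor DropoutSeed:
DropoutParam dropout_param0(context, 0);
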
template <typename T, typename MaskType>
class FusedDropoutHelper {
 private:
  int GetIncrement(const platform::CUDADeviceContext& ctx) {
    const int VecSize = MAX_CACHE_BYTES / sizeof(T);
    const int real_vec_size = cols_ % VecSize == 0 ? VecSize : 1;
    auto config =
        Get1DBlocksAnd2DGrids(ctx, static_cast<uint64_t>(rows_),
                              static_cast<uint64_t>(cols_), real_vec_size);
    int increment = ((cols_ - 1) / (config.thread_per_block.x *
                                    config.block_per_grid.x * real_vec_size) +
                     1) *
                    real_vec_size;
    increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment);
    return increment;
  }

 public:
  FusedDropoutHelper() {}
  FusedDropoutHelper(const platform::CUDADeviceContext& ctx, const int rows,
                     const int cols, const DropoutParam& dropout_param) {
    rows_ = rows;
    cols_ = cols;
    dropout_param_ = dropout_param;
  }

  // out = residual + dropout( src + bias )
  void ResidualDropoutBias(const platform::CUDADeviceContext& ctx, const T* src,
                           const T* residual, const T* bias, T* out,
                           MaskType* mask) {
    auto increment = GetIncrement(ctx);
    LaunchResidualDropoutBias<T, MaskType>(
        rows_, cols_, increment, dropout_param_.seed,
        dropout_param_.dropout_prob, dropout_param_.is_test,
        dropout_param_.is_upscale_in_train, src, residual, bias, mask, out,
        ctx);
  }

  void ResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx,
                               const T* d_out, const MaskType* mask, T* d_src,
                               T* d_residual, T* d_bias) {
    LaunchResidualDropoutBiasGrad<T, uint8_t>(
        d_out, mask, dropout_param_.dropout_prob,
        dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
    auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
    memory::Copy(cuda_place, d_residual, cuda_place, d_out,
                 rows_ * cols_ * sizeof(T), ctx.stream());
  }

  // out = dropout(activation(src + bias))
  void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src,
                      const T* bias, const std::string& act_method, T* out,
                      MaskType* mask) {
    auto increment = GetIncrement(ctx);
    if (act_method == "gelu") {
      GeluFunctor<T> gelu;
      LaunchDropoutActBias<T, MaskType, GeluFunctor<T>>(
          gelu, dropout_param_.seed, rows_, cols_, dropout_param_.increment,
          dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
          dropout_param_.is_test, src, bias, out, mask, ctx);
    } else if (act_method == "relu") {
      math::ReluFunctor<T> relu;
      LaunchDropoutActBias<T, MaskType, math::ReluFunctor<T>>(
          relu, dropout_param_.seed, rows_, cols_, increment,
          dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
          dropout_param_.is_test, src, bias, out, mask, ctx);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently only supports gelu or relu activation functions!"));
    }
  }

  void DropoutActBiasGrad(const platform::CUDADeviceContext& ctx, const T* dout,
                          const T* src, const T* bias, const MaskType* mask,
                          T* d_src, T* d_bias, const std::string& act_method) {
    if (act_method == "gelu") {
      GeluGradFunctor<T> gelu_grad;
      LaunchDropoutActBiasGrad<T, MaskType, GeluGradFunctor<T>>(
          gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
          dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
    } else if (act_method == "relu") {
      math::ReluGradFunctor<T> relu_grad;
      LaunchDropoutActBiasGrad<T, MaskType, math::ReluGradFunctor<T>>(
          relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
          dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently only supports gelu or relu activation functions!"));
    }
  }

 protected:
  int rows_;
  int cols_;
  DropoutParam dropout_param_;
};

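A rough usage sketch for the forward pass of an FFN-style block. Everything here is illustrative (tensor pointers, dimensions and the surrounding op); the real call sites arrive with the follow-up fused_attention_op and fused_ffn PRs:

// rows = batch_size * seq_len; the two dropouts act on tensors of different
// widths, so each gets its own helper instance.
FusedDropoutHelper<float, uint8_t> dropout1(dev_ctx, rows, dim_feedforward,
                                            dropout_param1);
FusedDropoutHelper<float, uint8_t> dropout2(dev_ctx, rows, d_model,
                                            dropout_param2);
// out1 = dropout1(gelu(linear1_out + bias1))
dropout1.DropoutActBias(dev_ctx, linear1_out, bias1, "gelu", out1, mask1);
// out2 = residual + dropout2(linear2_out + bias2)   (pre_layer_norm == true)
dropout2.ResidualDropoutBias(dev_ctx, linear2_out, residual, bias2, out2,
                             mask2);
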
template <typename T, typename MaskType>
class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> {
 public:
  FusedDropoutLayerNormHelper() {}
  FusedDropoutLayerNormHelper(const int rows, const int cols,
                              const float epsilon) {
    using U = LayerNormParamType<T>;
    this->rows_ = rows;
    this->cols_ = cols;
    epsilon_ = epsilon;
  }

  FusedDropoutLayerNormHelper(const platform::CUDADeviceContext& ctx,
                              const int rows, const int cols,
                              const DropoutParam& dropout_param,
                              const float epsilon)
      : FusedDropoutHelper<T, MaskType>(ctx, rows, cols, dropout_param) {
    using U = LayerNormParamType<T>;
    epsilon_ = epsilon;
  }

  // call layer_norm
  void LayerNorm(const platform::CUDADeviceContext& ctx, const T* src,
                 const LayerNormParamType<T>* gamma,
                 const LayerNormParamType<T>* beta, T* out,
                 LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) {
    using U = LayerNormParamType<T>;
    switch (GetDesiredBlockDim(this->cols_)) {
      FIXED_BLOCK_DIM_CASE(
          LayerNormForward<
              T, U, kBlockDim><<<this->rows_, kBlockDim, 0, ctx.stream()>>>(
              src, gamma, beta, out, mean, variance, epsilon_, this->cols_));
    }
  }

  void LayerNormGrad(const platform::CUDADeviceContext& ctx, const T* dout,
                     const T* src, const LayerNormParamType<T>* gamma,
                     const LayerNormParamType<T>* mean,
                     const LayerNormParamType<T>* variance, T* d_src,
                     LayerNormParamType<T>* d_scale,
                     LayerNormParamType<T>* d_bias) {
    using U = LayerNormParamType<T>;
    LayerNormBackward<T, U>(src, dout, gamma, mean, variance, d_src, d_scale,
                            d_bias, epsilon_, this->rows_, this->cols_, ctx);
  }

  // out = layernorm(residual + dropout(src + bias))
  void LayernormResidualDropoutBias(
      const platform::CUDADeviceContext& ctx, const T* src, const T* residual,
      const T* bias, const LayerNormParamType<T>* gamma,
      const LayerNormParamType<T>* beta, T* dropout_out, MaskType* mask, T* out,
      LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) {
    using U = LayerNormParamType<T>;
    int vec_size = MAX_CACHE_BYTES / sizeof(T);
    if (this->cols_ % vec_size != 0) {
      vec_size = 1;
    }
    int threads = GetDesiredBlockDim(this->cols_ / vec_size);
    int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size;
    increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment);
    LaunchLayernormResidualDropoutBias<T, MaskType>(
        this->rows_, this->cols_, increment, this->dropout_param_.seed,
        this->dropout_param_.dropout_prob, epsilon_,
        this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test,
        src, residual, bias, gamma, beta, mask, dropout_out, out, mean,
        variance, ctx);
  }

  void LayernormResidualDropoutBiasGrad(
      const platform::CUDADeviceContext& ctx, const T* d_out,
      const T* layernorm_src, const MaskType* mask,
      const LayerNormParamType<T>* gamma, const LayerNormParamType<T>* mean,
      const LayerNormParamType<T>* variance, T* d_layernorm_src,
      LayerNormParamType<T>* d_scale, LayerNormParamType<T>* d_layernorm_bias,
      T* d_dropout_src, T* d_bias, T* d_residual) {
    using U = LayerNormParamType<T>;
    LayerNormBackward<T, U>(layernorm_src, d_out, gamma, mean, variance,
                            d_layernorm_src, d_scale, d_layernorm_bias,
                            epsilon_, this->rows_, this->cols_, ctx);
    this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src,
                                  d_residual, d_bias);
  }

 protected:
  float epsilon_;
};

}  // namespace operators
}  // namespace paddle
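
And a matching sketch for pre_layer_norm == false, where the residual add, dropout and LayerNorm of the second half are fused into a single launch (again illustrative; all buffers are assumed to be allocated by the caller):

FusedDropoutLayerNormHelper<float, uint8_t> ln_helper(dev_ctx, rows, d_model,
                                                      dropout_param2, epsilon);
// out = layer_norm(residual + dropout2(linear2_out + bias2))
ln_helper.LayernormResidualDropoutBias(dev_ctx, linear2_out, residual, bias2,
                                       gamma, beta, dropout_out, mask2, out,
                                       mean, variance);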
