Commit 2b88057

Refactor dropout cuda impl for code reuse. (#35621)
1 parent e26a250 commit 2b88057

3 files changed (+336 −251 lines)

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace operators {

template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  MaskType mask_val;
  T dst_val;
  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    T src_val = src[idx];
#ifdef PADDLE_WITH_HIP
    if (hiprand_uniform(&state) < dropout_prob) {
#else
    if (curand_uniform(&state) < dropout_prob) {
#endif
      mask_val = 0;
      dst_val = 0;
    } else {
      mask_val = 1;
      dst_val = is_upscale_in_train ? src_val * factor : src_val;
    }
    mask[idx] = mask_val;
    dst[idx] = dst_val;
  }
}

template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment) {
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

#ifdef PADDLE_WITH_HIP
  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
    LoadT src_val;
    platform::Load<T, VecSize>(&src[i], &src_val);

#ifdef PADDLE_WITH_HIP
    float4 rand = hiprand_uniform4(&state);
#else
    float4 rand = curand_uniform4(&state);
#endif

    LoadT dst_val;
    MaskLoadT mask_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if ((&rand.x)[j] < dropout_prob) {
        dst_val[j] = 0;
        mask_val[j] = 0;
      } else {
        dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
        mask_val[j] = 1;
      }
    }

    platform::Store<T, VecSize>(dst_val, &dst[i]);
    platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
  }
}

template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
                                      const T factor, const int64_t size,
                                      T* dx) {
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
    LoadT dout_val;
    platform::Load<T, VecSize>(&dout[i], &dout_val);

    MaskLoadT mask_val;
    platform::Load<MaskType, VecSize>(&mask[i], &mask_val);

    LoadT dx_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
    }

    platform::Store<T, VecSize>(dx_val, &dx[i]);
  }
}

template <typename T>
void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                              bool is_test,
                              const std::string dropout_implementation,
                              float dropout_prob, bool upscale_in_train,
                              bool is_fix_seed, int seed_val, const Tensor& x,
                              const Tensor* seed, Tensor* mask, Tensor* y) {
  auto& place = *dev_ctx.eigen_device();

  if (!is_test) {
    int64_t x_numel = x.numel();
    auto stream = dev_ctx.stream();
    auto* mask_data = mask->data<uint8_t>();
    size_t size = framework::product(mask->dims());

    auto* x_data = x.data<T>();
    auto* y_data = y->data<T>();
    if (dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_CUDA_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
      return;
    }

    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(dev_ctx, size);

    // increment is used to set the args(offset) of curand_init, which defines
    // offset in subsequence.
    // The detail:
    // https://docs.nvidia.com/cuda/curand/device-api-overview.html
    // Increment should be at least the number of curand() random numbers used
    // in each thread to avoid the random number generated this time being the
    // same as the previous calls.
    uint64_t seed_data;
    uint64_t increment;
    int vec_size = platform::GetVectorizedSize<T>(x_data);
    auto offset = ((x_numel - 1) / (config.block_per_grid.x *
                                    config.thread_per_block.x * vec_size) +
                   1) *
                  vec_size;
    int device_id =
        BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

    if ((seed) && platform::is_gpu_place(seed->place())) {
      framework::Tensor seed_cpu_tensor;
      TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
      seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
      increment = offset;
    } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
      auto seed_offset = gen_cuda->IncrementOffset(offset);
      seed_data = seed_offset.first;
      increment = seed_offset.second;
    } else {
      if (seed) {
        seed_data = *(seed->data<int>());
      } else {
        std::random_device rnd;
        seed_data = is_fix_seed ? seed_val : rnd();
      }
      increment = offset;
    }

#ifdef __HIPCC__
    if (vec_size == 4 && size % 4 == 0) {
      hipLaunchKernelGGL(
          HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
          config.block_per_grid, config.thread_per_block, 0, stream, size,
          seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
          increment);
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
                         config.block_per_grid, config.thread_per_block, 0,
                         stream, size, seed_data, dropout_prob, x_data,
                         mask_data, y_data, upscale_in_train, increment);
    }
#else
    if (vec_size == 4 && size % 4 == 0) {
      VectorizedRandomGenerator<
          T, uint8_t,
          4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    } else {
      RandomGenerator<T, uint8_t><<<config.block_per_grid,
                                    config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    }
#endif
  } else {
    auto X = EigenMatrix<T>::Reshape(x, 1);
    auto Y = EigenMatrix<T>::Reshape(*y, 1);
    if (upscale_in_train) {
      Y.device(place) = X;
    } else {
      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
    }
  }
}

template <typename T>
void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                const std::string dropout_implementation,
                                float dropout_prob, const Tensor& grad_y,
                                const Tensor& mask, int64_t size,
                                Tensor* grad_x) {
  auto M = EigenVector<uint8_t>::Flatten(mask);
  auto dX = EigenVector<T>::Flatten(*grad_x);
  auto dY = EigenVector<T>::Flatten(grad_y);

  auto& place = *dev_ctx.eigen_device();
  if (dropout_implementation == "upscale_in_train") {
    if (dropout_prob == 1.0f) {
      dX.device(place) = static_cast<T>(0) * dY;
    } else {
      int vec_size = platform::GetVectorizedSize<T>(grad_y.data<T>());
      if (vec_size == 4 && size % 4 == 0) {
        auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
        auto stream = dev_ctx.stream();
        platform::GpuLaunchConfig config =
            platform::GetGpuLaunchConfig1D(dev_ctx, size);
        DropoutGradCUDAKernel<
            T, uint8_t,
            4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
            grad_y.data<T>(), mask.data<uint8_t>(), factor, size,
            grad_x->data<T>());
      } else {
        dX.device(place) =
            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
      }
    }
  } else {
    dX.device(place) = dY * M.cast<T>();
  }
}

}  // namespace operators
}  // namespace paddle
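
The comment above the seed handling explains why the driver computes `offset`/`increment` the way it does: each thread's Philox subsequence must advance by at least the number of random values that thread consumes, so the next launch never replays the same numbers. A small host-side illustration of that arithmetic follows; the launch dimensions and element count are made-up values, not anything produced by `GetGpuLaunchConfig1D`.

```cpp
#include <cstdint>
#include <cstdio>

// Reproduces the offset formula from DropoutFwGPUKernelDriver with assumed
// (hypothetical) launch parameters, purely to show what the value means.
int main() {
  const int64_t x_numel = 10'000'000;  // elements to drop out (assumed)
  const int64_t blocks = 1024;         // config.block_per_grid.x (assumed)
  const int64_t threads = 256;         // config.thread_per_block.x (assumed)
  const int64_t vec_size = 4;          // float4 / curand_uniform4 path

  // Same expression as the driver: grid-stride iterations per thread,
  // rounded up, times the number of uniforms drawn per iteration.
  const int64_t offset =
      ((x_numel - 1) / (blocks * threads * vec_size) + 1) * vec_size;

  // Each of the 1024*256 threads runs at most 10 vectorized iterations, and
  // each curand_uniform4/hiprand_uniform4 call yields 4 numbers, so a thread
  // consumes at most 40 values; offset == 40 pushes the subsequence offset
  // past everything consumed by this launch.
  std::printf("offset = %lld\n", static_cast<long long>(offset));
  return 0;
}
```

The same bound also covers the scalar `RandomGenerator` fallback, which runs up to four times as many grid-stride iterations but draws only one uniform per iteration.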
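
The other two files touched by this commit (the dropout operator's CUDA kernels) are not shown in this view. As a rough sketch of the reuse the commit title refers to, a forward op kernel built on this header might now reduce to a thin wrapper around `DropoutFwGPUKernelDriver`. The `GPUDropoutKernel` class below, together with its attribute names and accessor calls, follows Paddle's usual fluid `OpKernel` conventions but is an illustrative assumption, not the commit's actual code.

```cpp
// Hypothetical sketch only; include the refactored header shown above
// (its path is not visible in this diff) plus the dropout op declaration.

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;  // alias commonly provided by dropout_op.h

template <typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    y->mutable_data<T>(context.GetPlace());

    float dropout_prob = context.Attr<float>("dropout_prob");
    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");
    bool is_test = context.Attr<bool>("is_test");
    bool fix_seed = context.Attr<bool>("fix_seed");
    int seed_val = context.Attr<int>("seed");

    auto* mask = context.Output<Tensor>("Mask");
    mask->mutable_data<uint8_t>(context.GetPlace());

    auto& dev_ctx = context.cuda_device_context();
    // Seed handling, vectorization checks, and kernel launches now live in
    // one shared driver instead of being duplicated in each op.
    DropoutFwGPUKernelDriver<T>(dev_ctx, is_test, dropout_implementation,
                                dropout_prob, upscale_in_train, fix_seed,
                                seed_val, *x, seed, mask, y);
  }
};

}  // namespace operators
}  // namespace paddle
```

A backward kernel would delegate to `DropoutGradGPUKernelDriver` in the same way, passing the mask, the upstream gradient, and the output gradient tensor.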
