/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <random>
#include <string>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#include "paddle/fluid/platform/dynload/curand.h"
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#include "paddle/fluid/platform/dynload/hiprand.h"
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/gpu_launch_config.h"

namespace paddle {
namespace operators {

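// Scalar dropout kernel. Each thread owns an independent Philox4x32-10
// subsequence (curand_init/hiprand_init takes the launch seed, the thread
// index as the subsequence id, and `increment` as the offset) and walks the
// input in a grid-stride loop, drawing one uniform float per element to
// decide whether it is zeroed.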
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, uint64_t seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask, T* dst,
                                bool is_upscale_in_train, uint64_t increment) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
#ifdef PADDLE_WITH_HIP
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  MaskType mask_val;
  T dst_val;
  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (; idx < n; idx += blockDim.x * gridDim.x) {
    T src_val = src[idx];
#ifdef PADDLE_WITH_HIP
    if (hiprand_uniform(&state) < dropout_prob) {
#else
    if (curand_uniform(&state) < dropout_prob) {
#endif
      mask_val = 0;
      dst_val = 0;
    } else {
      mask_val = 1;
      dst_val = is_upscale_in_train ? src_val * factor : src_val;
    }
    mask[idx] = mask_val;
    dst[idx] = dst_val;
  }
}

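// Vectorized dropout kernel. Each loop iteration moves VecSize contiguous
// elements with one aligned vector load/store and consumes a single
// curand_uniform4/hiprand_uniform4 call, which returns exactly four uniform
// floats; the driver below therefore launches it only when vec_size == 4 and
// the element count is divisible by 4.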
template <typename T, typename MaskType, int VecSize>
__global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
                                          const float dropout_prob,
                                          const T* src, MaskType* mask, T* dst,
                                          bool is_upscale_in_train,
                                          uint64_t increment) {
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

#ifdef PADDLE_WITH_HIP
  int64_t idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
  hiprandStatePhilox4_32_10_t state;
  hiprand_init(seed, idx, increment, &state);
#else
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, increment, &state);
#endif

  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
  for (int i = idx * VecSize; i < n; i += blockDim.x * gridDim.x * VecSize) {
    LoadT src_val;
    platform::Load<T, VecSize>(&src[i], &src_val);

#ifdef PADDLE_WITH_HIP
    float4 rand = hiprand_uniform4(&state);
#else
    float4 rand = curand_uniform4(&state);
#endif

    LoadT dst_val;
    MaskLoadT mask_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      if ((&rand.x)[j] < dropout_prob) {
        dst_val[j] = 0;
        mask_val[j] = 0;
      } else {
        dst_val[j] = is_upscale_in_train ? src_val[j] * factor : src_val[j];
        mask_val[j] = 1;
      }
    }

    platform::Store<T, VecSize>(dst_val, &dst[i]);
    platform::Store<MaskType, VecSize>(mask_val, &mask[i]);
  }
}

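// Vectorized dropout backward kernel: dx[i] = dout[i] * mask[i] * factor,
// where the caller passes factor = 1 / (1 - dropout_prob) for the
// "upscale_in_train" implementation.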
template <typename T, typename MaskType, int VecSize>
__global__ void DropoutGradCUDAKernel(const T* dout, const MaskType* mask,
                                      const T factor, const int64_t size,
                                      T* dx) {
  using LoadT = platform::AlignedVector<T, VecSize>;
  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;

  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
    LoadT dout_val;
    platform::Load<T, VecSize>(&dout[i], &dout_val);

    MaskLoadT mask_val;
    platform::Load<MaskType, VecSize>(&mask[i], &mask_val);

    LoadT dx_val;

#pragma unroll
    for (int j = 0; j < VecSize; j++) {
      dx_val[j] = dout_val[j] * static_cast<T>(mask_val[j]) * factor;
    }

    platform::Store<T, VecSize>(dx_val, &dx[i]);
  }
}

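// Forward driver. In training mode it picks a seed (in order of priority: a
// GPU seed tensor copied to the host, the default CUDA generator when it was
// initialized from Python and no fixed seed is requested, a CPU seed tensor,
// and finally seed_val or std::random_device), then dispatches to the
// vectorized or scalar kernel. In test mode dropout is the identity under
// "upscale_in_train" and a scale by (1 - dropout_prob) otherwise.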
template <typename T>
void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                              bool is_test,
                              const std::string dropout_implementation,
                              float dropout_prob, bool upscale_in_train,
                              bool is_fix_seed, int seed_val, const Tensor& x,
                              const Tensor* seed, Tensor* mask, Tensor* y) {
  auto& place = *dev_ctx.eigen_device();

  if (!is_test) {
    int64_t x_numel = x.numel();
    auto stream = dev_ctx.stream();
    auto* mask_data = mask->data<uint8_t>();
    size_t size = framework::product(mask->dims());

    auto* x_data = x.data<T>();
    auto* y_data = y->data<T>();
    if (dropout_prob == 1.0f) {
#ifdef PADDLE_WITH_HIP
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          hipMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#else
      PADDLE_ENFORCE_CUDA_SUCCESS(
          cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
      PADDLE_ENFORCE_CUDA_SUCCESS(
          cudaMemsetAsync(mask_data, 0, x_numel * sizeof(*mask_data), stream));
#endif
      return;
    }

    platform::GpuLaunchConfig config =
        platform::GetGpuLaunchConfig1D(dev_ctx, size);

    // `increment` is used as the offset argument of curand_init, which
    // positions the generator within its subsequence (see
    // https://docs.nvidia.com/cuda/curand/device-api-overview.html).
    // It must be at least the number of random numbers consumed by each
    // thread in one launch, so that the numbers drawn in this launch do not
    // repeat those drawn in previous launches.
    uint64_t seed_data;
    uint64_t increment;
    int vec_size = platform::GetVectorizedSize<T>(x_data);
    auto offset = ((x_numel - 1) / (config.block_per_grid.x *
                                    config.thread_per_block.x * vec_size) +
                   1) *
                  vec_size;
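    // A worked instance of the formula above (numbers hypothetical): with
    // x_numel = 1000, block_per_grid.x * thread_per_block.x = 256 and
    // vec_size = 4, offset = ((999 / 1024) + 1) * 4 = 4, i.e. every thread
    // may consume up to four random numbers in this launch.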
    int device_id =
        BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()).GetDeviceId();
    auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);

    if ((seed) && platform::is_gpu_place(seed->place())) {
      framework::Tensor seed_cpu_tensor;
      TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
      seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
      increment = offset;
    } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
      auto seed_offset = gen_cuda->IncrementOffset(offset);
      seed_data = seed_offset.first;
      increment = seed_offset.second;
    } else {
      if (seed) {
        seed_data = *(seed->data<int>());
      } else {
        std::random_device rnd;
        seed_data = is_fix_seed ? seed_val : rnd();
      }
      increment = offset;
    }

#ifdef __HIPCC__
    if (vec_size == 4 && size % 4 == 0) {
      hipLaunchKernelGGL(
          HIP_KERNEL_NAME(VectorizedRandomGenerator<T, uint8_t, 4>),
          config.block_per_grid, config.thread_per_block, 0, stream, size,
          seed_data, dropout_prob, x_data, mask_data, y_data, upscale_in_train,
          increment);
    } else {
      hipLaunchKernelGGL(HIP_KERNEL_NAME(RandomGenerator<T, uint8_t>),
                         config.block_per_grid, config.thread_per_block, 0,
                         stream, size, seed_data, dropout_prob, x_data,
                         mask_data, y_data, upscale_in_train, increment);
    }
#else
    if (vec_size == 4 && size % 4 == 0) {
      VectorizedRandomGenerator<
          T, uint8_t,
          4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    } else {
      RandomGenerator<T, uint8_t><<<config.block_per_grid,
                                    config.thread_per_block, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train, increment);
    }
#endif
  } else {
    auto X = EigenMatrix<T>::Reshape(x, 1);
    auto Y = EigenMatrix<T>::Reshape(*y, 1);
    if (upscale_in_train) {
      Y.device(place) = X;
    } else {
      Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
    }
  }
}

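// Backward driver. For "upscale_in_train", dx = dy * mask / (1 -
// dropout_prob), computed by the vectorized kernel when the data is
// 4-aligned and by Eigen otherwise (with dx = 0 when dropout_prob == 1);
// for the "downgrade_in_infer" variant, dx = dy * mask.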
template <typename T>
void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                const std::string dropout_implementation,
                                float dropout_prob, const Tensor& grad_y,
                                const Tensor& mask, int64_t size,
                                Tensor* grad_x) {
  auto M = EigenVector<uint8_t>::Flatten(mask);
  auto dX = EigenVector<T>::Flatten(*grad_x);
  auto dY = EigenVector<T>::Flatten(grad_y);

  auto& place = *dev_ctx.eigen_device();
  if (dropout_implementation == "upscale_in_train") {
    if (dropout_prob == 1.0f) {
      dX.device(place) = static_cast<T>(0) * dY;
    } else {
      int vec_size = platform::GetVectorizedSize<T>(grad_y.data<T>());
      if (vec_size == 4 && size % 4 == 0) {
        auto factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
        auto stream = dev_ctx.stream();
        platform::GpuLaunchConfig config =
            platform::GetGpuLaunchConfig1D(dev_ctx, size);
        DropoutGradCUDAKernel<
            T, uint8_t,
            4><<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
            grad_y.data<T>(), mask.data<uint8_t>(), factor, size,
            grad_x->data<T>());
      } else {
        dX.device(place) =
            dY * M.cast<T>() / static_cast<T>(1.0f - dropout_prob);
      }
    }
  } else {
    dX.device(place) = dY * M.cast<T>();
  }
}

}  // namespace operators
}  // namespace paddle