|
| 1 | +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#pragma once |
| 16 | + |
| 17 | +#include "paddle/fluid/framework/generator.h" |
| 18 | +#include "paddle/fluid/operators/dropout_impl_util.h" |
| 19 | +#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" |
| 20 | +#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h" |
| 21 | +#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" |
| 22 | +#include "paddle/fluid/operators/math/functors.h" |
| 23 | + |
| 24 | +namespace paddle { |
| 25 | +namespace operators { |
| 26 | + |
| 27 | +/** |
| 28 | + * Support two Dropouts in the use senarieo. |
| 29 | + * This warpper can be used in FFN op. |
| 30 | + * The DropoutParam will be used in the fused_dropout_act_bias, |
| 31 | + * fused_residual_dropout_bias(pre_layer_norm=ture) or |
| 32 | + * fused_layernorm_residual_dropout_bias(pre_layer_norm=false). |
| 33 | +*/ |
| 34 | +struct DropoutParam { |
| 35 | + uint64_t seed; |
| 36 | + float dropout_prob; |
| 37 | + bool is_upscale_in_train; |
| 38 | + bool is_test; |
| 39 | + bool fix_seed; |
| 40 | + int increment; |
| 41 | + const framework::Tensor* tensor_seed; |
| 42 | + int seed_val; |
| 43 | + |
| 44 | + DropoutParam() { |
| 45 | + fix_seed = false; |
| 46 | + seed = 0; |
| 47 | + is_test = false; |
| 48 | + is_upscale_in_train = false; |
| 49 | + dropout_prob = 0.5; |
| 50 | + tensor_seed = nullptr; |
| 51 | + seed_val = 0; |
| 52 | + } |
| 53 | + |
| 54 | + /** |
| 55 | + * dropout_index: can be 0, 1, 2. 0 means there is only one dropout, |
| 56 | + * 1 and 2 represent two dropout, the parameter name of dropout |
| 57 | + * will be "dropout" + dropout_index + param name, such as dropout1_seed, |
| 58 | + * dropout1_is_test. |
| 59 | + */ |
| 60 | + DropoutParam(const framework::ExecutionContext& context, |
| 61 | + const int dropout_index) { |
| 62 | + std::string pre_fix = "dropout"; |
| 63 | + std::string str_index = std::to_string(dropout_index); |
| 64 | + if (dropout_index > 0) { |
| 65 | + pre_fix = pre_fix + str_index + "_"; |
| 66 | + } else { |
| 67 | + pre_fix = pre_fix + "_"; |
| 68 | + } |
| 69 | + dropout_prob = context.Attr<float>(pre_fix + "prob"); |
| 70 | + auto& dropout_implementation = |
| 71 | + context.Attr<std::string>(pre_fix + "implementation"); |
| 72 | + is_upscale_in_train = (dropout_implementation == "upscale_in_train"); |
| 73 | + is_test = context.Attr<bool>(pre_fix + "is_test"); |
| 74 | + fix_seed = context.Attr<bool>(pre_fix + "fix_seed"); |
| 75 | + |
| 76 | + std::string str_seed = "Dropout"; |
| 77 | + if (dropout_index > 0) { |
| 78 | + str_seed = str_seed + str_index + "Seed"; |
| 79 | + } else { |
| 80 | + str_seed = str_seed + "Seed"; |
| 81 | + } |
| 82 | + tensor_seed = |
| 83 | + context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr; |
| 84 | + seed_val = context.Attr<int>(pre_fix + "seed"); |
| 85 | + } |
| 86 | + |
| 87 | + int UpdateSeedAndIncrement(const platform::CUDADeviceContext& ctx, |
| 88 | + const int offset) { |
| 89 | + uint64_t tmp_increment; |
| 90 | + GetSeedDataAndIncrement(ctx, tensor_seed, fix_seed, seed_val, offset, &seed, |
| 91 | + &tmp_increment); |
| 92 | + increment = static_cast<int>(tmp_increment); |
| 93 | + return increment; |
| 94 | + } |
| 95 | +}; |
| 96 | + |
| 97 | +template <typename T, typename MaskType> |
| 98 | +class FusedDropoutHelper { |
| 99 | + private: |
| 100 | + int GetIncrement(const platform::CUDADeviceContext& ctx) { |
| 101 | + const int VecSize = MAX_CACHE_BYTES / sizeof(T); |
| 102 | + const int real_vec_size = cols_ % VecSize == 0 ? VecSize : 1; |
| 103 | + auto config = |
| 104 | + Get1DBlocksAnd2DGrids(ctx, static_cast<uint64_t>(rows_), |
| 105 | + static_cast<uint64_t>(cols_), real_vec_size); |
| 106 | + int increment = ((cols_ - 1) / (config.thread_per_block.x * |
| 107 | + config.block_per_grid.x * real_vec_size) + |
| 108 | + 1) * |
| 109 | + real_vec_size; |
| 110 | + increment = dropout_param_.UpdateSeedAndIncrement(ctx, increment); |
| 111 | + return increment; |
| 112 | + } |
| 113 | + |
| 114 | + public: |
| 115 | + FusedDropoutHelper() {} |
| 116 | + FusedDropoutHelper(const platform::CUDADeviceContext& ctx, const int rows, |
| 117 | + const int cols, const DropoutParam& dropout_param) { |
| 118 | + rows_ = rows; |
| 119 | + cols_ = cols; |
| 120 | + dropout_param_ = dropout_param; |
| 121 | + } |
| 122 | + |
| 123 | + // out = residual + dropout( src + bias ) |
| 124 | + void ResidualDropoutBias(const platform::CUDADeviceContext& ctx, const T* src, |
| 125 | + const T* residual, const T* bias, T* out, |
| 126 | + MaskType* mask) { |
| 127 | + auto increment = GetIncrement(ctx); |
| 128 | + LaunchResidualDropoutBias<T, MaskType>( |
| 129 | + rows_, cols_, increment, dropout_param_.seed, |
| 130 | + dropout_param_.dropout_prob, dropout_param_.is_test, |
| 131 | + dropout_param_.is_upscale_in_train, src, residual, bias, mask, out, |
| 132 | + ctx); |
| 133 | + } |
| 134 | + |
| 135 | + void ResidualDropoutBiasGrad(const platform::CUDADeviceContext& ctx, |
| 136 | + const T* d_out, const MaskType* mask, T* d_src, |
| 137 | + T* d_residual, T* d_bias) { |
| 138 | + LaunchResidualDropoutBiasGrad<T, uint8_t>( |
| 139 | + d_out, mask, dropout_param_.dropout_prob, |
| 140 | + dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); |
| 141 | + auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); |
| 142 | + memory::Copy(cuda_place, d_residual, cuda_place, d_out, |
| 143 | + rows_ * cols_ * sizeof(T), ctx.stream()); |
| 144 | + } |
| 145 | + |
| 146 | + // out = dropout(activation(src + bias)) |
| 147 | + void DropoutActBias(const platform::CUDADeviceContext& ctx, const T* src, |
| 148 | + const T* bias, const std::string& act_method, T* out, |
| 149 | + MaskType* mask) { |
| 150 | + auto increment = GetIncrement(ctx); |
| 151 | + if (act_method == "gelu") { |
| 152 | + GeluFunctor<T> gelu; |
| 153 | + LaunchDropoutActBias<T, MaskType, GeluFunctor<T>>( |
| 154 | + gelu, dropout_param_.seed, rows_, cols_, dropout_param_.increment, |
| 155 | + dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, |
| 156 | + dropout_param_.is_test, src, bias, out, mask, ctx); |
| 157 | + } else if (act_method == "relu") { |
| 158 | + math::ReluFunctor<T> relu; |
| 159 | + LaunchDropoutActBias<T, MaskType, math::ReluFunctor<T>>( |
| 160 | + relu, dropout_param_.seed, rows_, cols_, increment, |
| 161 | + dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, |
| 162 | + dropout_param_.is_test, src, bias, out, mask, ctx); |
| 163 | + } else { |
| 164 | + PADDLE_THROW(platform::errors::InvalidArgument( |
| 165 | + "Currently only supports gelu or relu activation functions!")); |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + void DropoutActBiasGrad(const platform::CUDADeviceContext& ctx, const T* dout, |
| 170 | + const T* src, const T* bias, const MaskType* mask, |
| 171 | + T* d_src, T* d_bias, const std::string& act_method) { |
| 172 | + if (act_method == "gelu") { |
| 173 | + GeluGradFunctor<T> gelu_grad; |
| 174 | + LaunchDropoutActBiasGrad<T, MaskType, GeluGradFunctor<T>>( |
| 175 | + gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, |
| 176 | + dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); |
| 177 | + } else if (act_method == "relu") { |
| 178 | + math::ReluGradFunctor<T> relu_grad; |
| 179 | + LaunchDropoutActBiasGrad<T, MaskType, math::ReluGradFunctor<T>>( |
| 180 | + relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, |
| 181 | + dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); |
| 182 | + } else { |
| 183 | + PADDLE_THROW(platform::errors::InvalidArgument( |
| 184 | + "Currently only supports gelu or relu activation functions!")); |
| 185 | + } |
| 186 | + } |
| 187 | + |
| 188 | + protected: |
| 189 | + int rows_; |
| 190 | + int cols_; |
| 191 | + DropoutParam dropout_param_; |
| 192 | +}; |
| 193 | + |
| 194 | +template <typename T, typename MaskType> |
| 195 | +class FusedDropoutLayerNormHelper : public FusedDropoutHelper<T, MaskType> { |
| 196 | + public: |
| 197 | + FusedDropoutLayerNormHelper() {} |
| 198 | + FusedDropoutLayerNormHelper(const int rows, const int cols, |
| 199 | + const float epsilon) { |
| 200 | + using U = LayerNormParamType<T>; |
| 201 | + this->rows_ = rows; |
| 202 | + this->cols_ = cols; |
| 203 | + epsilon_ = epsilon; |
| 204 | + } |
| 205 | + |
| 206 | + FusedDropoutLayerNormHelper(const platform::CUDADeviceContext& ctx, |
| 207 | + const int rows, const int cols, |
| 208 | + const DropoutParam& dropout_param, |
| 209 | + const float epsilon) |
| 210 | + : FusedDropoutHelper<T, MaskType>(ctx, rows, cols, dropout_param) { |
| 211 | + using U = LayerNormParamType<T>; |
| 212 | + epsilon_ = epsilon; |
| 213 | + } |
| 214 | + |
| 215 | + // call layer_norm |
| 216 | + void LayerNorm(const platform::CUDADeviceContext& ctx, const T* src, |
| 217 | + const LayerNormParamType<T>* gamma, |
| 218 | + const LayerNormParamType<T>* beta, T* out, |
| 219 | + LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) { |
| 220 | + using U = LayerNormParamType<T>; |
| 221 | + switch (GetDesiredBlockDim(this->cols_)) { |
| 222 | + FIXED_BLOCK_DIM_CASE( |
| 223 | + LayerNormForward< |
| 224 | + T, U, kBlockDim><<<this->rows_, kBlockDim, 0, ctx.stream()>>>( |
| 225 | + src, gamma, beta, out, mean, variance, epsilon_, this->cols_)); |
| 226 | + } |
| 227 | + } |
| 228 | + |
| 229 | + void LayerNormGrad(const platform::CUDADeviceContext& ctx, const T* dout, |
| 230 | + const T* src, const LayerNormParamType<T>* gamma, |
| 231 | + const LayerNormParamType<T>* mean, |
| 232 | + const LayerNormParamType<T>* variance, T* d_src, |
| 233 | + LayerNormParamType<T>* d_scale, |
| 234 | + LayerNormParamType<T>* d_bias) { |
| 235 | + using U = LayerNormParamType<T>; |
| 236 | + LayerNormBackward<T, U>(src, dout, gamma, mean, variance, d_src, d_scale, |
| 237 | + d_bias, epsilon_, this->rows_, this->cols_, ctx); |
| 238 | + } |
| 239 | + |
| 240 | + // out = layernorm(residual + dropout(src + bias)) |
| 241 | + void LayernormResidualDropoutBias( |
| 242 | + const platform::CUDADeviceContext& ctx, const T* src, const T* residual, |
| 243 | + const T* bias, const LayerNormParamType<T>* gamma, |
| 244 | + const LayerNormParamType<T>* beta, T* dropout_out, MaskType* mask, T* out, |
| 245 | + LayerNormParamType<T>* mean, LayerNormParamType<T>* variance) { |
| 246 | + using U = LayerNormParamType<T>; |
| 247 | + int vec_size = MAX_CACHE_BYTES / sizeof(T); |
| 248 | + if (this->cols_ % vec_size != 0) { |
| 249 | + vec_size = 1; |
| 250 | + } |
| 251 | + int threads = GetDesiredBlockDim(this->cols_ / vec_size); |
| 252 | + int increment = ((this->cols_ - 1) / (threads * vec_size) + 1) * vec_size; |
| 253 | + increment = this->dropout_param_.UpdateSeedAndIncrement(ctx, increment); |
| 254 | + LaunchLayernormResidualDropoutBias<T, MaskType>( |
| 255 | + this->rows_, this->cols_, increment, this->dropout_param_.seed, |
| 256 | + this->dropout_param_.dropout_prob, epsilon_, |
| 257 | + this->dropout_param_.is_upscale_in_train, this->dropout_param_.is_test, |
| 258 | + src, residual, bias, gamma, beta, mask, dropout_out, out, mean, |
| 259 | + variance, ctx); |
| 260 | + } |
| 261 | + |
| 262 | + void LayernormResidualDropoutBiasGrad( |
| 263 | + const platform::CUDADeviceContext& ctx, const T* d_out, |
| 264 | + const T* layernorm_src, const MaskType* mask, |
| 265 | + const LayerNormParamType<T>* gamma, const LayerNormParamType<T>* mean, |
| 266 | + const LayerNormParamType<T>* variance, T* d_layernorm_src, |
| 267 | + LayerNormParamType<T>* d_scale, LayerNormParamType<T>* d_layernorm_bias, |
| 268 | + T* d_dropout_src, T* d_bias, T* d_residual) { |
| 269 | + using U = LayerNormParamType<T>; |
| 270 | + LayerNormBackward<T, U>(layernorm_src, d_out, gamma, mean, variance, |
| 271 | + d_layernorm_src, d_scale, d_layernorm_bias, |
| 272 | + epsilon_, this->rows_, this->cols_, ctx); |
| 273 | + this->ResidualDropoutBiasGrad(ctx, d_layernorm_src, mask, d_dropout_src, |
| 274 | + d_residual, d_bias); |
| 275 | + } |
| 276 | + |
| 277 | + protected: |
| 278 | + float epsilon_; |
| 279 | +}; |
| 280 | + |
| 281 | +} // namespace operators |
| 282 | +} // namespace paddle |
0 commit comments