 // limitations under the License.

 #include "paddle/phi/kernels/adagrad_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
 #include "paddle/phi/kernels/impl/adagrad_kernel_impl.h"

 namespace phi {

+template <typename T, typename MT>
+__global__ void AdagradGPUKernel(const T* param,
+                                 const T* grad,
+                                 const MT* moment,
+                                 const MT* lr,
+                                 const MT* master_param,
+                                 MT epsilon,
+                                 T* param_out,
+                                 MT* moment_out,
+                                 MT* master_param_out,
+                                 int num) {
+  auto idx = blockDim.x * blockIdx.x + threadIdx.x;
+  // lr already holds MT; routing it through T would truncate the value.
+  MT lr_data = static_cast<MT>(lr[0]);
+
+  // Grid-stride loop: each thread advances by the total thread count, so
+  // any launch configuration covers all `num` elements.
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    MT grad_data = static_cast<MT>(grad[i]);
+    MT moment_out_data = static_cast<MT>(moment[i]) + grad_data * grad_data;
+    moment_out[i] = moment_out_data;
+    // Prefer the full-precision master copy of the parameter when present.
+    MT in = master_param_out ? master_param[i] : static_cast<MT>(param[i]);
+    MT param_out_data =
+        in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon);
+
+    // Store in the parameter's (possibly low-precision) type.
+    param_out[i] = static_cast<T>(param_out_data);
+
+    if (master_param_out) {
+      master_param_out[i] = param_out_data;
+    }
+  }
+}
+
| 57 | + |
| 58 | +template <typename T> |
| 59 | +struct DenseAdagradFunctor<phi::GPUContext, T> { |
| 60 | + void operator()(const phi::GPUContext& ctx, |
| 61 | + const DenseTensor& param_t, |
| 62 | + const DenseTensor& grad_t, |
| 63 | + const DenseTensor& moment_t, |
| 64 | + const DenseTensor& learning_rate, |
| 65 | + const paddle::optional<DenseTensor>& master_param, |
| 66 | + float epsilon_t, |
| 67 | + bool multi_precision, |
| 68 | + DenseTensor* param_out_tensor, |
| 69 | + DenseTensor* moment_out_tensor, |
| 70 | + DenseTensor* master_param_outs) { |
| 71 | + using MPDType = typename phi::dtype::template MPTypeTrait<T>::Type; |
| 72 | + T* param_out_data = ctx.template Alloc<T>(param_out_tensor); |
| 73 | + MPDType* moment_out_data = ctx.template Alloc<MPDType>(moment_out_tensor); |
| 74 | + const MPDType* master_in_data = |
| 75 | + multi_precision ? master_param->data<MPDType>() : nullptr; |
| 76 | + MPDType* master_out_data = |
| 77 | + multi_precision ? ctx.template Alloc<MPDType>(master_param_outs) |
| 78 | + : nullptr; |
| 79 | + |
| 80 | + MPDType epsilon = static_cast<MPDType>(epsilon_t); |
| 81 | + |
| 82 | + int numel = param_t.numel(); |
| 83 | + auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1); |
| 84 | + int grid = config.block_per_grid.x; |
| 85 | + int block = config.thread_per_block.x; |
| 86 | + auto stream = ctx.stream(); |
| 87 | + AdagradGPUKernel<T, MPDType> |
| 88 | + <<<block, grid, 0, stream>>>(param_t.data<T>(), |
| 89 | + grad_t.data<T>(), |
| 90 | + moment_t.data<MPDType>(), |
| 91 | + learning_rate.data<MPDType>(), |
| 92 | + master_in_data, |
| 93 | + epsilon, |
| 94 | + param_out_data, |
| 95 | + moment_out_data, |
| 96 | + master_out_data, |
| 97 | + numel); |
| 98 | + } |
| 99 | +}; |
| 100 | + |
 template <typename T, int block_size>
 __global__ void MergeGradKernel(const T* grad,
                                 const int64_t* grad_rows,
@@ -123,11 +198,19 @@ struct SparseAdagradFunctor<phi::GPUContext, T> {

 template struct SparseAdagradFunctor<phi::GPUContext, float>;
 template struct SparseAdagradFunctor<phi::GPUContext, double>;
+template struct DenseAdagradFunctor<phi::GPUContext, float>;
+template struct DenseAdagradFunctor<phi::GPUContext, double>;
+template struct DenseAdagradFunctor<phi::GPUContext, phi::dtype::float16>;

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    adagrad, GPU, ALL_LAYOUT, phi::AdagradDenseKernel, float, double) {}
+PD_REGISTER_KERNEL(adagrad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AdagradDenseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}

 PD_REGISTER_KERNEL(adagrad_dense_param_sparse_grad,
                    GPU,
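A note on the launch configuration: `GetGpuLaunchConfig1D` chooses the grid and block sizes, and because the kernel iterates with a grid-stride loop, correctness does not depend on the grid exactly covering `numel`. A minimal sketch of the pattern follows; the kernel name and body are hypothetical, only the loop structure mirrors the patch.

```cpp
// Hypothetical sketch of the grid-stride pattern used by AdagradGPUKernel.
__global__ void GridStrideCopy(const float* in, float* out, int num) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;  // total threads in the whole launch
  for (int i = idx; i < num; i += stride) {
    out[i] = in[i];  // each thread handles elements idx, idx+stride, ...
  }
}
```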