
Commit 4779c2c

Add multi_precision for adagrad op (#50078)
1 parent c647cac commit 4779c2c

14 files changed: +699 -74 lines changed


paddle/fluid/operators/optimizers/adagrad_op.cc (+9)

@@ -43,14 +43,23 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Grad", "(Tensor) Input gradient");
     AddInput("Moment", "(Tensor) Second moment");
     AddInput("LearningRate", "(Tensor) Learning rate");
+    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();

     AddOutput("ParamOut", "(Tensor) Output parameter");
     AddOutput("MomentOut", "(Tensor) Output second moment");
+    AddOutput("MasterParamOut",
+              "The updated FP32 master weight for AMP. "
+              "It shared memory with Input(MasterParam).")
+        .AsDispensable();

     AddAttr<float>("epsilon",
                    "(float, default 1.0e-6) "
                    "Constant for numerical stability")
         .SetDefault(1.0e-6f);
+    AddAttr<bool>("multi_precision",
+                  "(bool, default false) "
+                  "Whether to use multi-precision during weight updating.")
+        .SetDefault(false);
     AddComment(R"DOC(

Adaptive Gradient Algorithm (Adagrad).
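The new MasterParam/MasterParamOut pair and the multi_precision attribute follow the usual AMP master-weight scheme: the optimizer keeps an FP32 shadow of a float16 parameter, applies the Adagrad update to the shadow, and writes the cast-down result back to the parameter. A minimal per-element sketch of that step (illustrative C++, not the Paddle implementation; the function name and flat-array layout are assumptions):

#include <cmath>
#include <cstdint>

// Sketch only: one Adagrad step where an FP32 master copy of a low-precision
// parameter receives the precise update and the parameter is refreshed from it.
void AdagradStepWithMaster(const float* grad, float lr, float epsilon,
                           float* moment, float* master_param, float* param,
                           int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    moment[i] += grad[i] * grad[i];                                      // m <- m + g*g
    master_param[i] -= lr * grad[i] / (std::sqrt(moment[i]) + epsilon);  // FP32 update
    param[i] = master_param[i];  // in the real op this is a cast down to float16
  }
}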

paddle/fluid/pybind/eager_generator.h (+3 -1)

@@ -205,6 +205,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"sparse_attention",
      {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
     {"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
+    {"adagrad", {"Param", "Grad", "Moment", "LearningRate", "MasterParam"}},
     {"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}},
     {"nce",
      {"Input",
@@ -361,6 +362,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
      "Beta2PowOut",
      "MasterParamOut"}},
     {"sgd", {"ParamOut", "MasterParamOut"}},
+    {"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
     {"lamb",
      {"ParamOut",
       "Moment1Out",
@@ -399,7 +401,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      "MasterParamOut"}},
     {"ftrl", {"ParamOut", "SquaredAccumOut", "LinearAccumOut"}},
     {"adadelta", {"ParamOut", "AvgSquaredGradOut", "AvgSquaredUpdateOut"}},
-    {"adagrad", {"ParamOut", "MomentOut"}},
+    {"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
     {"adamax", {"ParamOut", "MomentOut", "InfNormOut"}},
     {"dpsgd", {"ParamOut"}},
     {"decayed_adagrad", {"ParamOut", "MomentOut"}},

paddle/phi/api/yaml/legacy_ops.yaml (+6 -5)

@@ -29,15 +29,16 @@
   inplace : (param -> param_out), (avg_squared_grad -> moment_out), (avg_squared_update -> inf_norm_out)

 - op : adagrad_
-  args : (Tensor param, Tensor grad, Tensor moment, Tensor learning_rate, float epsilon)
-  output : Tensor(param_out), Tensor(moment_out)
+  args : (Tensor param, Tensor grad, Tensor moment, Tensor learning_rate, Tensor master_param, float epsilon, bool multi_precision)
+  output : Tensor(param_out), Tensor(moment_out), Tensor(master_param_out)
   infer_meta :
     func : AdagradInferMeta
   kernel :
-    func : adagrad {dense, dense, dense, dense -> dense, dense}
-           adagrad_dense_param_sparse_grad {dense, selected_rows, dense, dense -> dense, dense}
+    func : adagrad {dense, dense, dense, dense, dense -> dense, dense, dense}
+           adagrad_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense -> dense, dense, dense}
   data_type : param
-  inplace : (param -> param_out), (moment -> moment_out)
+  optional : master_param
+  inplace : (param -> param_out), (moment -> moment_out), (master_param -> master_param_out)

 - op : adam_
   args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)

paddle/phi/infermeta/multiary.cc (+4 -1)

@@ -74,9 +74,12 @@ void AdagradInferMeta(const MetaTensor& param,
                       const MetaTensor& grad,
                       const MetaTensor& moment,
                       const MetaTensor& learning_rate,
+                      const MetaTensor& master_param,
                       float epsilon,
+                      bool multi_precision,
                       MetaTensor* param_out,
-                      MetaTensor* moment_out) {
+                      MetaTensor* moment_out,
+                      MetaTensor* master_param_out) {
   auto lr_dims = learning_rate.dims();
   PADDLE_ENFORCE_EQ(
       phi::product(lr_dims),

paddle/phi/infermeta/multiary.h (+4 -1)

@@ -53,9 +53,12 @@ void AdagradInferMeta(const MetaTensor& param,
                      const MetaTensor& grad,
                      const MetaTensor& moment,
                      const MetaTensor& learning_rate,
+                     const MetaTensor& master_param,
                      float epsilon,
+                     bool multi_precision,
                      MetaTensor* param_out,
-                     MetaTensor* moment_out);
+                     MetaTensor* moment_out,
+                     MetaTensor* master_param_out);

 void AdamaxInferMeta(const MetaTensor& param,
                      const MetaTensor& grad,

paddle/phi/kernels/adagrad_kernel.h (+8 -2)

@@ -25,18 +25,24 @@ void AdagradDenseKernel(const Context& dev_ctx,
                         const DenseTensor& grad,
                         const DenseTensor& moment,
                         const DenseTensor& learning_rate,
+                        const paddle::optional<DenseTensor>& master_param,
                         float epsilon,
+                        bool multi_precision,
                         DenseTensor* param_out,
-                        DenseTensor* moment_out);
+                        DenseTensor* moment_out,
+                        DenseTensor* master_param_outs);

 template <typename T, typename Context>
 void AdagradSparseKernel(const Context& dev_ctx,
                          const DenseTensor& param,
                          const SelectedRows& grad,
                          const DenseTensor& moment,
                          const DenseTensor& learning_rate,
+                         const paddle::optional<DenseTensor>& master_param,
                          float epsilon,
+                         bool multi_precision,
                          DenseTensor* param_out,
-                         DenseTensor* moment_out);
+                         DenseTensor* moment_out,
+                         DenseTensor* master_param_outs);

 }  // namespace phi
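In the PHI kernel signatures the dispensable MasterParam surfaces as a paddle::optional<DenseTensor>, which is only dereferenced when multi_precision is set (the GPU functor later in this commit does exactly that with master_param->data<MPDType>()). A rough sketch of the pattern, using std::optional and a vector as stand-ins; the type alias and helper name are illustrative, not Paddle APIs:

#include <optional>
#include <vector>

// Stand-ins for DenseTensor / paddle::optional in this sketch.
using TensorSketch = std::vector<float>;

// Returns the FP32 master buffer when multi-precision is enabled, otherwise
// nullptr so the caller falls back to the parameter itself.
const float* MasterDataOrNull(const std::optional<TensorSketch>& master_param,
                              bool multi_precision) {
  return (multi_precision && master_param.has_value()) ? master_param->data()
                                                       : nullptr;
}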

paddle/phi/kernels/cpu/adagrad_kernel.cc (+38)

@@ -28,6 +28,42 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
 }
 }  // namespace

+template <typename T>
+struct DenseAdagradFunctor<phi::CPUContext, T> {
+  void operator()(const phi::CPUContext& ctx,
+                  const DenseTensor& param_t,
+                  const DenseTensor& grad_t,
+                  const DenseTensor& moment_t,
+                  const DenseTensor& learning_rate,
+                  const paddle::optional<DenseTensor>& master_param,
+                  float epsilon_t,
+                  bool multi_precision,
+                  DenseTensor* param_out_tensor,
+                  DenseTensor* moment_out_tensor,
+                  DenseTensor* master_param_outs) {
+    ctx.template Alloc<T>(param_out_tensor);
+    ctx.template Alloc<T>(moment_out_tensor);
+
+    T epsilon = static_cast<T>(epsilon_t);
+
+    auto param = EigenVector<T>::Flatten(param_t);
+
+    auto grad = EigenVector<T>::Flatten(grad_t);
+
+    auto moment = EigenVector<T>::Flatten(moment_t);
+
+    auto param_out = EigenVector<T>::Flatten(*param_out_tensor);
+    auto moment_out = EigenVector<T>::Flatten(*moment_out_tensor);
+    auto place = *ctx.eigen_device();
+
+    moment_out.device(place) = moment + grad * grad;
+    Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
+    auto* lr = learning_rate.data<T>();
+    param_out.device(place) =
+        param - lr[0] * grad / (moment_out.sqrt() + epsilon);
+  }
+};
+
 template <typename T>
 struct SparseAdagradFunctor<phi::CPUContext, T> {
   void operator()(const phi::CPUContext& context,
@@ -67,6 +103,8 @@ struct SparseAdagradFunctor<phi::CPUContext, T> {

 template struct SparseAdagradFunctor<phi::CPUContext, float>;
 template struct SparseAdagradFunctor<phi::CPUContext, double>;
+template struct DenseAdagradFunctor<phi::CPUContext, float>;
+template struct DenseAdagradFunctor<phi::CPUContext, double>;

 }  // namespace phi
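The Eigen expressions above implement the dense element-wise rule m <- m + g*g, p <- p - lr*g / (sqrt(m) + eps); note that this CPU specialization, as shown, never reads master_param. A tiny self-contained check of that arithmetic with hand-picked values (illustrative only, not taken from the test suite):

#include <cmath>
#include <cstdio>

int main() {
  float param = 1.0f, grad = 0.5f, moment = 0.0f;
  const float lr = 0.1f, epsilon = 1.0e-6f;

  moment += grad * grad;                               // 0.25
  param -= lr * grad / (std::sqrt(moment) + epsilon);  // 1.0 - 0.05 / 0.500001

  std::printf("moment=%.6f param=%.6f\n", moment, param);  // moment=0.250000 param~=0.900000
  return 0;
}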

paddle/phi/kernels/gpu/adagrad_kernel.cu (+86 -3)

@@ -13,16 +13,91 @@
 // limitations under the License.

 #include "paddle/phi/kernels/adagrad_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
 #include "paddle/phi/kernels/impl/adagrad_kernel_impl.h"

 namespace phi {

+template <typename T, typename MT>
+__global__ void AdagradGPUKernel(const T* param,
+                                 const T* grad,
+                                 const MT* moment,
+                                 const MT* lr,
+                                 const MT* master_param,
+                                 MT epsilon,
+                                 T* param_out,
+                                 MT* moment_out,
+                                 MT* master_param_out,
+                                 int num) {
+  auto idx = blockDim.x * blockIdx.x + threadIdx.x;
+  MT lr_data = static_cast<T>(lr[0]);
+
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    MT grad_data = static_cast<MT>(grad[i]);
+    MT moment_out_data = static_cast<MT>(moment[i]) + grad_data * grad_data;
+    moment_out[i] = static_cast<MT>(moment_out_data);
+    auto in = master_param_out ? master_param[i] : static_cast<MT>(param[i]);
+    MT param_out_data =
+        in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon);
+
+    param_out[i] = static_cast<MT>(param_out_data);
+
+    if (master_param_out) {
+      master_param_out[i] = param_out_data;
+    }
+  }
+}
+
+template <typename T>
+struct DenseAdagradFunctor<phi::GPUContext, T> {
+  void operator()(const phi::GPUContext& ctx,
+                  const DenseTensor& param_t,
+                  const DenseTensor& grad_t,
+                  const DenseTensor& moment_t,
+                  const DenseTensor& learning_rate,
+                  const paddle::optional<DenseTensor>& master_param,
+                  float epsilon_t,
+                  bool multi_precision,
+                  DenseTensor* param_out_tensor,
+                  DenseTensor* moment_out_tensor,
+                  DenseTensor* master_param_outs) {
+    using MPDType = typename phi::dtype::template MPTypeTrait<T>::Type;
+    T* param_out_data = ctx.template Alloc<T>(param_out_tensor);
+    MPDType* moment_out_data = ctx.template Alloc<MPDType>(moment_out_tensor);
+    const MPDType* master_in_data =
+        multi_precision ? master_param->data<MPDType>() : nullptr;
+    MPDType* master_out_data =
+        multi_precision ? ctx.template Alloc<MPDType>(master_param_outs)
+                        : nullptr;

+    MPDType epsilon = static_cast<MPDType>(epsilon_t);
+
+    int numel = param_t.numel();
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1);
+    int grid = config.block_per_grid.x;
+    int block = config.thread_per_block.x;
+    auto stream = ctx.stream();
+    AdagradGPUKernel<T, MPDType>
+        <<<block, grid, 0, stream>>>(param_t.data<T>(),
+                                     grad_t.data<T>(),
+                                     moment_t.data<MPDType>(),
+                                     learning_rate.data<MPDType>(),
+                                     master_in_data,
+                                     epsilon,
+                                     param_out_data,
+                                     moment_out_data,
+                                     master_out_data,
+                                     numel);
+  }
+};
+
 template <typename T, int block_size>
 __global__ void MergeGradKernel(const T* grad,
                                 const int64_t* grad_rows,
@@ -123,11 +198,19 @@ struct SparseAdagradFunctor<phi::GPUContext, T> {

 template struct SparseAdagradFunctor<phi::GPUContext, float>;
 template struct SparseAdagradFunctor<phi::GPUContext, double>;
+template struct DenseAdagradFunctor<phi::GPUContext, float>;
+template struct DenseAdagradFunctor<phi::GPUContext, double>;
+template struct DenseAdagradFunctor<phi::GPUContext, phi::dtype::float16>;

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    adagrad, GPU, ALL_LAYOUT, phi::AdagradDenseKernel, float, double) {}
+PD_REGISTER_KERNEL(adagrad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AdagradDenseKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}

 PD_REGISTER_KERNEL(adagrad_dense_param_sparse_grad,
                    GPU,
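The GPU functor picks its compute type through phi::dtype::MPTypeTrait<T>, so a float16 parameter is accumulated and updated in float while float and double keep their own type, and the adagrad kernel is now also registered for phi::dtype::float16. A simplified sketch of what such a trait boils down to (illustrative C++, not Paddle's actual definition; the Float16Sketch type is a stand-in):

#include <cstdint>
#include <type_traits>

// Minimal stand-in for a 16-bit float storage type.
struct Float16Sketch { uint16_t bits; };

// Default: the math ("master precision") type is the storage type itself.
template <typename T>
struct MPTypeTraitSketch {
  using Type = T;
};

// Half-precision storage is promoted to float for the optimizer math.
template <>
struct MPTypeTraitSketch<Float16Sketch> {
  using Type = float;
};

static_assert(std::is_same<MPTypeTraitSketch<double>::Type, double>::value,
              "full-precision types keep their own math type");
static_assert(std::is_same<MPTypeTraitSketch<Float16Sketch>::Type, float>::value,
              "float16 storage is computed in float");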

paddle/phi/kernels/impl/adagrad_kernel_impl.h (+35 -28)

@@ -30,6 +30,21 @@ struct SparseAdagradFunctor {
                   DenseTensor* param);
 };

+template <typename DeviceContext, typename T>
+struct DenseAdagradFunctor {
+  void operator()(const DeviceContext& ctx,
+                  const DenseTensor& param_t,
+                  const DenseTensor& grad_t,
+                  const DenseTensor& moment_t,
+                  const DenseTensor& learning_rate,
+                  const paddle::optional<DenseTensor>& master_param,
+                  float epsilon_t,
+                  bool multi_precision,
+                  DenseTensor* param_out_tensor,
+                  DenseTensor* moment_out_tensor,
+                  DenseTensor* master_param_outs);
+};
+
 template <typename DeviceContext, typename T>
 phi::SelectedRows SquareSelectedRows(const DeviceContext& context,
                                      const phi::SelectedRows& input) {
@@ -50,35 +65,24 @@ void AdagradDenseKernel(const Context& ctx,
                         const DenseTensor& grad_t,
                         const DenseTensor& moment_t,
                         const DenseTensor& learning_rate,
+                        const paddle::optional<DenseTensor>& master_param,
                         float epsilon_t,
+                        bool multi_precision,
                         DenseTensor* param_out_tensor,
-                        DenseTensor* moment_out_tensor) {
-  ctx.template Alloc<T>(param_out_tensor);
-  ctx.template Alloc<T>(moment_out_tensor);
-
-  T epsilon = static_cast<T>(epsilon_t);
-
-  auto param = EigenVector<T>::Flatten(param_t);
-
-  auto grad = EigenVector<T>::Flatten(grad_t);
-
-  auto moment = EigenVector<T>::Flatten(moment_t);
-
-  auto param_out = EigenVector<T>::Flatten(*param_out_tensor);
-  auto moment_out = EigenVector<T>::Flatten(*moment_out_tensor);
-  auto place = *ctx.eigen_device();
-
-  moment_out.device(place) = moment + grad * grad;
-  Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
-  if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
-    auto* lr = learning_rate.data<T>();
-    param_out.device(place) =
-        param - lr[0] * grad / (moment_out.sqrt() + epsilon);
-  } else {
-    auto lr = EigenVector<T>::Flatten(learning_rate);
-    param_out.device(place) =
-        param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
-  }
+                        DenseTensor* moment_out_tensor,
+                        DenseTensor* master_param_outs) {
+  DenseAdagradFunctor<Context, T> functor;
+  functor(ctx,
+          param_t,
+          grad_t,
+          moment_t,
+          learning_rate,
+          master_param,
+          epsilon_t,
+          multi_precision,
+          param_out_tensor,
+          moment_out_tensor,
+          master_param_outs);
 }

 template <typename T, typename Context>
@@ -87,9 +91,12 @@ void AdagradSparseKernel(const Context& ctx,
                          const SelectedRows& grad_t,
                          const DenseTensor& moment_t,
                          const DenseTensor& learning_rate,
+                         const paddle::optional<DenseTensor>& master_param,
                          float epsilon_t,
+                         bool multi_precision,
                          DenseTensor* param_out,
-                         DenseTensor* moment_out) {
+                         DenseTensor* moment_out,
+                         DenseTensor* master_param_outs) {
   auto* param_out_tensor = param_out;
   auto* moment_out_tensor = moment_out;
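With this refactor AdagradDenseKernel no longer contains the update itself; it forwards to DenseAdagradFunctor, whose definitions live in the CPU and GPU source files as explicit specializations. A stripped-down, illustrative sketch of that dispatch pattern (the names and the trivial update below are placeholders, not Paddle code):

#include <cstdio>

// Placeholder device-context tag standing in for phi::CPUContext / phi::GPUContext.
struct CpuCtxSketch {};

// Primary template is only declared; every backend supplies its own specialization.
template <typename DeviceContext, typename T>
struct DenseUpdateFunctorSketch;

template <typename T>
struct DenseUpdateFunctorSketch<CpuCtxSketch, T> {
  void operator()(const CpuCtxSketch&, T* param, const T* grad, int n) const {
    for (int i = 0; i < n; ++i) param[i] -= grad[i];  // trivial stand-in update
  }
};

// The device-agnostic kernel just instantiates the right functor and forwards.
template <typename T, typename Context>
void DenseUpdateKernelSketch(const Context& ctx, T* param, const T* grad, int n) {
  DenseUpdateFunctorSketch<Context, T>()(ctx, param, grad, n);
}

int main() {
  float param[2] = {1.0f, 2.0f};
  const float grad[2] = {0.1f, 0.2f};
  DenseUpdateKernelSketch(CpuCtxSketch{}, param, grad, 2);
  std::printf("%.1f %.1f\n", param[0], param[1]);  // prints: 0.9 1.8
  return 0;
}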

paddle/phi/kernels/xpu/adagrad_kernel.cc (+4 -1)

@@ -24,9 +24,12 @@ void AdagradDenseKernel(const Context& ctx,
                         const DenseTensor& grad,
                         const DenseTensor& moment,
                         const DenseTensor& learning_rate,
+                        const paddle::optional<DenseTensor>& master_param,
                         float epsilon_t,
+                        bool multi_precision,
                         DenseTensor* param_out_tensor,
-                        DenseTensor* moment_out_tensor) {
+                        DenseTensor* moment_out_tensor,
+                        DenseTensor* master_param_outs) {
   ctx.template Alloc<T>(param_out_tensor);
   ctx.template Alloc<T>(moment_out_tensor);
