Commit df68414

fix all ctest errors and change the LARS compute code for CPU

1 parent 0deaa40 · commit df68414

3 files changed: +66 −34 lines

paddle/fluid/operators/optimizers/lars_momentum_op.cu

+4 −8

@@ -178,8 +178,8 @@ __global__ void L2NormKernel(
 #if CUDA_VERSION >= 11000
   // Grid sync for completely writing partial result back to global memory
   cg->sync();
-  MT p_partial_sum = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
-  MT g_partial_sum = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
+  MT p_partial_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
+  MT g_partial_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
   *p_n = sqrt(math::blockReduceSum<MT>(p_partial_sum, FINAL_MASK));
   *g_n = sqrt(rescale_grad_pow *
               math::blockReduceSum<MT>(g_partial_sum, FINAL_MASK));
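Note on this hunk: L2NormKernel writes one partial sum of squares per block into p_buffer / g_buffer and then reduces those partials in a second stage, so the number of valid entries to gather is the launch grid size itself; bounding the gather by gridDim.x is presumably why the separately tracked thresh value is no longer needed. Restating what the kernel returns (taking rescale_grad_pow in the code above to be the squared gradient-rescaling factor):

    \lVert p \rVert_2 = \sqrt{\textstyle\sum_j p_j^2}, \qquad
    \lVert g \rVert_2 = \sqrt{\text{rescale\_grad}^2 \,\textstyle\sum_j g_j^2}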
@@ -193,7 +193,6 @@ struct MergedParameter {
  public:
   int64_t numel_arr[LARS_MAX_MERGED_OPS];
   int repeat_arr[LARS_MAX_MERGED_OPS];
-  int thresh_arr[LARS_MAX_MERGED_OPS];
   const T* __restrict__ p_arr[LARS_MAX_MERGED_OPS];
   const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
   const MT* __restrict__ v_arr[LARS_MAX_MERGED_OPS];
@@ -222,8 +221,7 @@ __global__ void MergedMomentumLarsKernel(MergedParameter<T, MT>* merged_params,
     MT grad_norm = static_cast<MT>(0);
     L2NormKernel<T, MT>(&cg, merged_params->p_arr[i], merged_params->g_arr[i],
                         p_buffer, g_buffer, numel, merged_params->repeat_arr[i],
-                        rescale_grad, merged_params->thresh_arr[i], &param_norm,
-                        &grad_norm);
+                        rescale_grad, 0, &param_norm, &grad_norm);
     const MT lr = *(merged_params->lr_arr[i]);
     const MT lars_weight_decay = merged_params->weight_decay_arr[i];
     MT local_lr = lr;
@@ -418,11 +416,9 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
     for (int i = 0; i < op_num; ++i) {
       grid_num = (merged_params.numel_arr[i] + LARS_BLOCK_SIZE - 1) /
                  LARS_BLOCK_SIZE;
+      // The maximum block number for L2 norm kernel is grid_real.
       merged_params.repeat_arr[i] =
           (merged_params.numel_arr[i] + grid_stride - 1) / grid_stride - 1;
-      // The maximum block number for L2 norm kernel is grid_real.
-      merged_params.thresh_arr[i] =
-          merged_params.repeat_arr[i] > 0 ? grid_real : grid_num;
     }
     if (multi_precision) {
       auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
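Host-side note: repeat_arr[i] still records how many additional grid-stride passes the L2 norm kernel makes over the i-th tensor,

    \text{repeat}_i = \left\lceil \frac{\text{numel}_i}{\text{grid\_stride}} \right\rceil - 1,

and because the partial-sum gather inside the kernel is now bounded by gridDim.x, the per-op thresh_arr bookkeeping is redundant; the kernel call above simply passes a constant 0 in that argument's position.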

paddle/fluid/operators/optimizers/lars_momentum_op.h

mode changed 100755 → 100644
+61 −25

@@ -23,6 +23,7 @@ template <typename T>
 class LarsMomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const bool merge_operation = ctx.Attr<bool>("merge_operation");
     auto param_out = ctx.MultiOutput<framework::LoDTensor>("ParamOut");
     auto velocity_out = ctx.MultiOutput<framework::LoDTensor>("VelocityOut");
     auto param = ctx.MultiInput<framework::LoDTensor>("Param");
@@ -38,39 +39,74 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {
                           framework::ToTypeName(grad_var[0]->Type())));
     auto grad = ctx.MultiInput<framework::LoDTensor>("Grad");

-    param_out[0]->mutable_data<T>(ctx.GetPlace());
-    velocity_out[0]->mutable_data<T>(ctx.GetPlace());
-
     T mu = static_cast<T>(ctx.Attr<float>("mu"));
     T lars_coeff = ctx.Attr<float>("lars_coeff");
-    T lars_weight_decay = ctx.Attr<std::vector<float>>("lars_weight_decay")[0];
     T epsilon = ctx.Attr<float>("epsilon");

-    auto p_out = framework::EigenVector<T>::Flatten(*(param_out[0]));
-    auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[0]));
+    if (!merge_operation) {
+      auto* lr = learning_rate[0]->data<T>();
+      T lars_weight_decay =
+          ctx.Attr<std::vector<float>>("lars_weight_decay")[0];
+      param_out[0]->mutable_data<T>(ctx.GetPlace());
+      velocity_out[0]->mutable_data<T>(ctx.GetPlace());
+
+      auto p_out = framework::EigenVector<T>::Flatten(*(param_out[0]));
+      auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[0]));
+      auto p = framework::EigenVector<T>::Flatten(*(param[0]));
+      auto v = framework::EigenVector<T>::Flatten(*(velocity[0]));
+      auto g = framework::EigenVector<T>::Flatten(*(grad[0]));
+
+      framework::Tensor p_norm_t, g_norm_t;
+      p_norm_t.Resize({1});
+      g_norm_t.Resize({1});
+      p_norm_t.mutable_data<T>(ctx.GetPlace());
+      g_norm_t.mutable_data<T>(ctx.GetPlace());
+      auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+      auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+      ep_norm = p.square().sum().sqrt();
+      eg_norm = g.square().sum().sqrt();
+
+      T local_lr = lr[0];
+      if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
+        local_lr = lr[0] * lars_coeff * ep_norm(0) /
+                   (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+      }
+      v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+      p_out = p - v_out;
+    } else {
+      int op_num = param.size();
+      auto weight_decay_arr = ctx.Attr<std::vector<float>>("lars_weight_decay");
+      for (int i = 0; i < op_num; ++i) {
+        auto* lr = learning_rate[i]->data<T>();
+        T lars_weight_decay = weight_decay_arr[i];
+        param_out[i]->mutable_data<T>(ctx.GetPlace());
+        velocity_out[i]->mutable_data<T>(ctx.GetPlace());

-    auto p = framework::EigenVector<T>::Flatten(*(param[0]));
-    auto v = framework::EigenVector<T>::Flatten(*(velocity[0]));
-    auto g = framework::EigenVector<T>::Flatten(*(grad[0]));
-    auto* lr = learning_rate[0]->data<T>();
+        auto p_out = framework::EigenVector<T>::Flatten(*(param_out[i]));
+        auto v_out = framework::EigenVector<T>::Flatten(*(velocity_out[i]));
+        auto p = framework::EigenVector<T>::Flatten(*(param[i]));
+        auto v = framework::EigenVector<T>::Flatten(*(velocity[i]));
+        auto g = framework::EigenVector<T>::Flatten(*(grad[i]));

-    framework::Tensor p_norm_t, g_norm_t;
-    p_norm_t.Resize({1});
-    g_norm_t.Resize({1});
-    p_norm_t.mutable_data<T>(ctx.GetPlace());
-    g_norm_t.mutable_data<T>(ctx.GetPlace());
-    auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
-    auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+        framework::Tensor p_norm_t, g_norm_t;
+        p_norm_t.Resize({1});
+        g_norm_t.Resize({1});
+        p_norm_t.mutable_data<T>(ctx.GetPlace());
+        g_norm_t.mutable_data<T>(ctx.GetPlace());
+        auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
+        auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
+        ep_norm = p.square().sum().sqrt();
+        eg_norm = g.square().sum().sqrt();

-    ep_norm = p.square().sum().sqrt();
-    eg_norm = g.square().sum().sqrt();
-    T local_lr = lr[0];
-    if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
-      local_lr = lr[0] * lars_coeff * ep_norm(0) /
-                 (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+        T local_lr = lr[0];
+        if (lars_weight_decay > 0 && ep_norm(0) > 0 && eg_norm(0) > 0) {
+          local_lr = lr[0] * lars_coeff * ep_norm(0) /
+                     (eg_norm(0) + lars_weight_decay * ep_norm(0) + epsilon);
+        }
+        v_out = v * mu + local_lr * (g + lars_weight_decay * p);
+        p_out = p - v_out;
+      }
     }
-    v_out = v * mu + local_lr * (g + lars_weight_decay * p);
-    p_out = p - v_out;
   }
 };

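Both branches of the CPU kernel above perform the same per-tensor LARS-momentum update; only the indexing differs (single tensor vs. a loop over the merged list). With \lambda standing for lars_weight_decay, the update is:

    \text{local\_lr} =
      \begin{cases}
        lr \cdot \text{lars\_coeff} \cdot \dfrac{\lVert p \rVert}{\lVert g \rVert + \lambda \lVert p \rVert + \epsilon}, & \lambda > 0,\ \lVert p \rVert > 0,\ \lVert g \rVert > 0 \\
        lr, & \text{otherwise}
      \end{cases}

    v \leftarrow \mu v + \text{local\_lr}\,(g + \lambda p), \qquad p \leftarrow p - v

As a minimal standalone sketch of that arithmetic (plain std::vector in place of Paddle tensors; the function name lars_cpu_step and its signature are illustrative only, not part of the Paddle API):

#include <cmath>
#include <cstddef>
#include <vector>

// One LARS-momentum step for a single parameter tensor, mirroring the CPU
// branch above: compute ||p|| and ||g||, rescale the learning rate by the
// trust ratio, then apply the momentum update in place.
template <typename T>
void lars_cpu_step(std::vector<T>* param, std::vector<T>* velocity,
                   const std::vector<T>& grad, T lr, T mu, T lars_coeff,
                   T lars_weight_decay, T epsilon) {
  T p_norm = 0;
  T g_norm = 0;
  for (std::size_t i = 0; i < param->size(); ++i) {
    p_norm += (*param)[i] * (*param)[i];
    g_norm += grad[i] * grad[i];
  }
  p_norm = std::sqrt(p_norm);
  g_norm = std::sqrt(g_norm);

  // Trust-ratio scaling, with the same guards as the kernel above.
  T local_lr = lr;
  if (lars_weight_decay > 0 && p_norm > 0 && g_norm > 0) {
    local_lr = lr * lars_coeff * p_norm /
               (g_norm + lars_weight_decay * p_norm + epsilon);
  }

  // v = mu * v + local_lr * (g + weight_decay * p);  p = p - v
  for (std::size_t i = 0; i < param->size(); ++i) {
    (*velocity)[i] = mu * (*velocity)[i] +
                     local_lr * (grad[i] + lars_weight_decay * (*param)[i]);
    (*param)[i] -= (*velocity)[i];
  }
}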

python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py

+1 −1

@@ -103,7 +103,7 @@ def test_lars_exclude_fn(self):
                 'op_role_var')[0] or ".b" in op.attr('op_role_var')[0])
         ]
         for op in ops_without_wd:
-            self.assertEqual(op.attr('lars_weight_decay'), 0)
+            self.assertEqual(op.attr('lars_weight_decay')[0], 0)

     def test_lars_apply_with_amp(self):
         role = role_maker.PaddleCloudRoleMaker(is_collective=True)
