
Commit 0deaa40

fix error
1 parent 0f924da commit 0deaa40

4 files changed (+12 −13 lines)


paddle/fluid/operators/optimizers/lars_momentum_op.cu

+6 −7

@@ -279,13 +279,12 @@ __global__ void MomentumLarsKernel(
                                rescale_grad, gridDim.x, &param_norm, &grad_norm);
 #else
   const MT rescale_grad_pow = rescale_grad * rescale_grad;
-  MT param_parital_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
-  MT grad_parital_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
+  MT param_part_norm = threadIdx.x < thresh ? p_buffer[threadIdx.x] : 0;
+  MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
   __syncthreads();
-  MT param_norm =
-      Sqrt(math::blockReduceSum<MT>(param_parital_norm, FINAL_MASK));
+  MT param_norm = Sqrt(math::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
   MT grad_norm = Sqrt(rescale_grad_pow *
-                      math::blockReduceSum<MT>(grad_parital_norm, FINAL_MASK));
+                      math::blockReduceSum<MT>(grad_part_norm, FINAL_MASK));
 #endif

   const MT lr = learning_rate[0];
@@ -499,9 +498,9 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
     MT* master_param_out_data = nullptr;

     if (multi_precision) {
-      auto master_param = ctx.MultiInput<framework::Tensor>("MasterParam");
+      auto master_param = ctx.MultiInput<framework::LoDTensor>("MasterParam");
       auto master_param_out =
-          ctx.MultiOutput<framework::Tensor>("MasterParamOut");
+          ctx.MultiOutput<framework::LoDTensor>("MasterParamOut");
       master_param_data = master_param[0]->data<MT>();
       master_param_out_data =
           master_param_out[0]->mutable_data<MT>(ctx.GetPlace());

paddle/fluid/operators/optimizers/lars_momentum_op.h

+1 −1

@@ -43,7 +43,7 @@ class LarsMomentumOpKernel : public framework::OpKernel<T> {

    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    T lars_coeff = ctx.Attr<float>("lars_coeff");
-   T lars_weight_decay = (ctx.Attr<std::vector<float>>("lars_weight_decay"))[0];
+   T lars_weight_decay = ctx.Attr<std::vector<float>>("lars_weight_decay")[0];
    T epsilon = ctx.Attr<float>("epsilon");

    auto p_out = framework::EigenVector<T>::Flatten(*(param_out[0]));

python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py

+2 −2

@@ -51,7 +51,7 @@ def net(self, main_prog, startup_prog):
         strategy.lars = True
         strategy.lars_configs = {
             "lars_coeff": 0.001,
-            "lars_weight_decay": [0.0005],
+            "lars_weight_decay": 0.0005,
             "epsilon": 0,
             "exclude_from_weight_decay": ["batch_norm", ".b"],
         }
@@ -134,7 +134,7 @@ def test_lars_apply_with_amp(self):
         strategy.lars = True
         strategy.lars_configs = {
             "lars_coeff": 0.001,
-            "lars_weight_decay": [0.0005],
+            "lars_weight_decay": 0.0005,
             "epsilon": 0,
             "exclude_from_weight_decay": ["batch_norm", ".b"],
         }
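
Note on the configuration change above: after this commit the lars_weight_decay entry in lars_configs is a plain float rather than a one-element list. A minimal sketch of how the strategy is built, using only the fleet APIs already exercised by this test (the surrounding network and optimizer setup are omitted):

from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.lars = True
strategy.lars_configs = {
    "lars_coeff": 0.001,
    "lars_weight_decay": 0.0005,  # scalar after this commit, previously [0.0005]
    "epsilon": 0,
    "exclude_from_weight_decay": ["batch_norm", ".b"],
}
# The strategy is then handed to fleet.distributed_optimizer(...), as the test does.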

python/paddle/fluid/tests/unittests/test_momentum_op.py

+3 −3

@@ -286,9 +286,9 @@ def setUp(self):
         grads = []
         velocitys = []
         learning_rates = []
-        master_params = []
         param_outs = []
         velocity_outs = []
+        master_params = []
         master_param_outs = []
         for i in range(self.params_num):
             master_param = np.random.random((123, 321)).astype("float32")
@@ -376,8 +376,8 @@ def setUp(self):
             gnorm = np.sqrt(np.square(grad).sum())
             local_lr = learning_rate * lars_coeff * pnorm / (
                 gnorm + lars_weight_decay[i] * param)
-            velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay[i]
-                                                       * param)
+            velocity_out = mu * velocity + local_lr * (
+                grad + lars_weight_decay[i] * param)
             param_out = param - velocity_out

             params.append(("SubParam_" + str(i), param))
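
For reference, the expected-value computation this test performs per sub-parameter can be written as a standalone NumPy helper. The function name and signature below are ours for illustration; the body mirrors the expressions in the diff above and would be called once per index i with lars_weight_decay[i]:

import numpy as np

def lars_reference_update(param, grad, velocity, learning_rate, mu,
                          lars_coeff, lars_weight_decay):
    # Norms of the parameter and gradient, as in the test's reference path.
    pnorm = np.sqrt(np.square(param).sum())
    gnorm = np.sqrt(np.square(grad).sum())
    # Layer-wise local learning rate used by LARS.
    local_lr = learning_rate * lars_coeff * pnorm / (
        gnorm + lars_weight_decay * param)
    # Momentum update with weight decay folded into the step, then the parameter step.
    velocity_out = mu * velocity + local_lr * (
        grad + lars_weight_decay * param)
    param_out = param - velocity_out
    return param_out, velocity_out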
