
Commit bc48eb9

add unittests
1 parent ea111a4 commit bc48eb9

3 files changed: 10 additions & 2 deletions

python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -148,15 +148,15 @@ def _dygraph_clip(self, params_grads):
             x=max_global_norm,
             y=layers.elementwise_max(
                 x=global_norm_var_fp32, y=max_global_norm))
+        clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
         for p, g in params_grads:
             if g is None:
                 continue
             if getattr(p, 'need_clip', True) is False:
                 params_and_grads.append((p, g))
                 continue
             if p.dtype == paddle.float16:
-                new_grad = layers.elementwise_mul(
-                    x=g, y=paddle.cast(clip_var, paddle.float16))
+                new_grad = layers.elementwise_mul(x=g, y=clip_var_fp16)
             else:
                 new_grad = layers.elementwise_mul(x=g, y=clip_var)
             params_and_grads.append((p, new_grad))
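
The change above hoists the fp16 cast of the clip coefficient out of the per-parameter loop: clip_var_fp16 is computed once and reused for every float16 gradient instead of re-casting clip_var on each iteration. A minimal standalone sketch of the same pattern, using the public paddle API rather than the internal layers.* helpers (tensor values are made up, and fp16 kernels generally assume a GPU):

import paddle

# clip coefficient derived from the global norm, as in _dygraph_clip (values made up)
max_global_norm = paddle.to_tensor(1.0, dtype='float32')
global_norm = paddle.to_tensor(2.0, dtype='float32')
clip_var = max_global_norm / paddle.maximum(global_norm, max_global_norm)

# cast once, outside the loop, instead of once per fp16 gradient
clip_var_fp16 = paddle.cast(clip_var, 'float16')

grads = [paddle.ones([4], dtype='float16'), paddle.ones([4], dtype='float32')]
clipped = []
for g in grads:
    if g.dtype == paddle.float16:
        clipped.append(g * clip_var_fp16)  # fp16 grads reuse the cached fp16 coefficient
    else:
        clipped.append(g * clip_var)       # fp32 grads use the fp32 coefficient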

python/paddle/fluid/tests/unittests/hybrid_parallel_pp_amp.py

Lines changed: 4 additions & 0 deletions
@@ -61,11 +61,14 @@ def test_pp_model(self):
         rank_id = dist.get_rank()
         set_random_seed(1024, dp_id, rank_id)
 
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)
+
         #construct model a
         model_a = AlexNet(10)
         scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
             boundaries=[2], values=[0.001, 0.002], verbose=True)
         optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
+                                           grad_clip=grad_clip,
                                            parameters=model_a.parameters())
 
         scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)
@@ -80,6 +83,7 @@ def test_pp_model(self):
         scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
             boundaries=[2], values=[0.001, 0.002], verbose=True)
         optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
+                                           grad_clip=grad_clip,
                                            parameters=model_b.parameters())
         model_b = fleet.distributed_model(model_b)
         optimizer_b = fleet.distributed_optimizer(optimizer_b)
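
Both unit tests construct a single ClipGradByGlobalNorm(1.0) and pass it to the reference optimizer and the pipeline-parallel optimizer through the grad_clip argument. Stripped of the fleet/pipeline setup, the wiring being tested looks roughly like this single-process sketch (the Linear model and random input are placeholders):

import paddle

model = paddle.nn.Linear(8, 2)                   # placeholder model
grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)  # clip all grads by their global norm
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[2], values=[0.001, 0.002])
opt = paddle.optimizer.SGD(learning_rate=scheduler,
                           grad_clip=grad_clip,
                           parameters=model.parameters())

x = paddle.randn([4, 8])
loss = model(x).mean()
loss.backward()
opt.step()          # gradients are clipped by global norm before the parameter update
opt.clear_grad()
scheduler.step()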

python/paddle/fluid/tests/unittests/hybrid_parallel_pp_fp16.py

Lines changed: 4 additions & 0 deletions
@@ -61,11 +61,14 @@ def test_pp_model(self):
         rank_id = dist.get_rank()
         set_random_seed(1024, dp_id, rank_id)
 
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)
+
         #construct model a
         model_a = AlexNet(10)
         scheduler_a = paddle.optimizer.lr.PiecewiseDecay(
             boundaries=[2], values=[0.001, 0.002], verbose=True)
         optimizer_a = paddle.optimizer.SGD(learning_rate=scheduler_a,
+                                           grad_clip=grad_clip,
                                            parameters=model_a.parameters())
 
         scaler_a = paddle.amp.GradScaler(init_loss_scaling=2**5)
@@ -75,6 +78,7 @@ def test_pp_model(self):
         scheduler_b = paddle.optimizer.lr.PiecewiseDecay(
             boundaries=[2], values=[0.001, 0.002], verbose=True)
         optimizer_b = paddle.optimizer.SGD(learning_rate=scheduler_b,
+                                           grad_clip=grad_clip,
                                            parameters=model_b.parameters())
 
         param_len = len(model_a.parameters())
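
hybrid_parallel_pp_fp16.py receives the same grad_clip wiring as the AMP test, so the clipping path is exercised together with a GradScaler. Outside the pipeline-parallel setup, that combination looks roughly like the following sketch (placeholder model and data; assumes a GPU with fp16 support):

import paddle

model = paddle.nn.Linear(8, 2)                   # placeholder model
grad_clip = paddle.nn.ClipGradByGlobalNorm(1.0)
opt = paddle.optimizer.SGD(learning_rate=0.001,
                           grad_clip=grad_clip,
                           parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=2**5)

x = paddle.randn([4, 8])
with paddle.amp.auto_cast():        # mixed-precision forward pass
    loss = model(x).mean()
scaled = scaler.scale(loss)         # scale the loss before backward
scaled.backward()
scaler.minimize(opt, scaled)        # unscale grads, apply grad_clip, update parameters
opt.clear_grad()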
