Skip to content

Commit 72600f3

Browse files
authored
Update semi_auto_parallel_sharding_stage_1.py
1 parent f9b39c1 commit 72600f3

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

test/auto_parallel/semi_auto_parallel_sharding_stage_1.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,34 @@ def run_sharding_test(enable_tensor_fusion):
234234
loss_enable = run_sharding_test(enable_tensor_fusion=True)
235235
self.check_tensor_eq(loss_disable, loss_enable)
236236

237+
def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion_with_chip(
    self,
):
    """Smoke-test sharding stage 1 with tensor fusion on a multi-dim mesh.

    Runs five AMP (O2) training steps of a small Linear model under
    ShardingStage1 with the tensor-fusion and in-place master-grad flags
    enabled. No value assertion is made — the test passes if the
    training loop completes without error.
    """
    dist.init_parallel_env()
    # Turn on the in-place master-grad and tensor-fusion code paths via
    # environment flags. NOTE(review): these are set after
    # init_parallel_env() but before the model/optimizer are built —
    # presumably the flags are read lazily during optimizer setup; the
    # ordering looks intentional, confirm before reordering.
    os.environ['FLAGS_enable_inplace_master_grad'] = '1'
    os.environ['FLAGS_enable_tensor_fusion'] = '1'
    # Use the multi-dimensional mesh as the global process mesh for this
    # test (unlike the single-mesh variants in this file).
    paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
    # Fixed seed for reproducible parameter init across ranks/runs.
    paddle.seed(self._seed)
    model = paddle.nn.Linear(10, 10)
    batch = paddle.rand(shape=[10, 10])
    # Shard the input batch along dim 0 over self._mesh.
    # NOTE(review): the batch is sharded on self._mesh while the global
    # mesh was set to self._multi_dim_mesh — verify this mix is intended.
    batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
    # Global-norm gradient clipping exercises the clip path together
    # with tensor fusion.
    clip = paddle.nn.ClipGradByGlobalNorm(1.0)
    opt = paddle.optimizer.AdamW(
        parameters=model.parameters(), grad_clip=clip
    )
    # Wrap the optimizer for sharding stage 1 over the "dp" mesh dim.
    opt = dist.shard_optimizer(
        opt, dist.ShardingStage1(sharding_mesh_dim="dp")
    )
    # AMP O2 decoration with master gradients; done after shard_optimizer
    # so the decorated optimizer is the sharded one.
    model, opt = paddle.amp.decorate(
        model, optimizers=opt, level='O2', master_grad=True
    )
    # A few forward/backward/step iterations; completion without raising
    # is the test's success criterion.
    for _ in range(5):
        with paddle.amp.auto_cast(level='O2'):
            loss = model(batch)
        loss.backward()
        opt.step()
        opt.clear_grad()
264+
237265
def run_test_case(self):
238266
if self._backend == "cpu":
239267
paddle.set_device("cpu")
@@ -251,6 +279,7 @@ def run_test_case(self):
251279
self.test_sharding_stage_1_overlap_to_static()
252280
self.test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad()
253281
self.test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion()
282+
self.test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion_with_chip()
254283

255284

256285
if __name__ == '__main__':

0 commit comments

Comments
 (0)