Skip to content

Commit 0459942

Browse files
authored
Update semi_auto_parallel_sharding_stage_1.py
1 parent 4e656d0 commit 0459942

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

test/auto_parallel/semi_auto_parallel_sharding_stage_1.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,36 @@ def test_sharding_stage_1_overlap_to_static(self):
173173
for batch_id, (image, label) in enumerate(dist_loader()):
174174
loss = dist_model(image, label)
175175

176+
def test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad(self):
177+
def run_sharding_test(enable_inplace_master_grad):
178+
os.environ['FLAGS_enable_inplace_master_grad'] = (
179+
'1' if enable_inplace_master_grad else '0'
180+
)
181+
paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
182+
paddle.seed(self._seed)
183+
model = paddle.nn.Linear(10, 10)
184+
batch = paddle.rand(shape=[10, 10])
185+
batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
186+
opt = paddle.optimizer.AdamW(parameters=model.parameters())
187+
opt = dist.shard_optimizer(
188+
opt, dist.ShardingStage1(sharding_mesh_dim="dp")
189+
)
190+
model, opt = paddle.amp.decorate(
191+
model, optimizers=opt, level='O2', master_grad=True
192+
)
193+
for _ in range(5):
194+
with paddle.amp.auto_cast(level='O2'):
195+
loss = model(batch)
196+
loss.backward()
197+
opt.step()
198+
opt.clear_grad()
199+
return loss.numpy()
200+
201+
dist.init_parallel_env()
202+
loss_disable = run_sharding_test(enable_inplace_master_grad=False)
203+
loss_enable = run_sharding_test(enable_inplace_master_grad=True)
204+
self.check_tensor_eq(loss_disable, loss_enable)
205+
176206
def run_test_case(self):
177207
if self._backend == "cpu":
178208
paddle.set_device("cpu")
@@ -188,6 +218,7 @@ def run_test_case(self):
188218
self.test_sharding_stage_1_to_static()
189219
self.test_pure_sharding_multi_mesh_stage_1()
190220
self.test_sharding_stage_1_overlap_to_static()
221+
self.test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad()
191222

192223

193224
if __name__ == '__main__':

0 commit comments

Comments
 (0)