@@ -173,6 +173,38 @@ def test_sharding_stage_1_overlap_to_static(self):
         for batch_id, (image, label) in enumerate(dist_loader()):
             loss = dist_model(image, label)
 
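+    # Verify that toggling FLAGS_enable_inplace_master_grad leaves the sharded AMP-O2 result unchanged.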
+    def test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad(self):
+        def run_sharding_test(enable_inplace_master_grad):
+            os.environ['FLAGS_enable_inplace_master_grad'] = (
+                '1' if enable_inplace_master_grad else '0'
+            )
+            paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
+            paddle.seed(self._seed)
+            model = paddle.nn.Linear(10, 10)
+            batch = paddle.rand(shape=[10, 10])
+            batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
+            opt = paddle.optimizer.AdamW(parameters=model.parameters())
+            opt = dist.shard_optimizer(
+                opt, dist.ShardingStage1(sharding_mesh_dim="dp")
+            )
+            model, opt = paddle.amp.decorate(
+                model, optimizers=opt, level='O2', master_grad=True
+            )
+            for _ in range(5):
+                with paddle.amp.auto_cast(level='O2'):
+                    loss = model(batch)
+                loss.backward()
+                opt.step()
+                opt.clear_grad()
+            return loss.numpy()
+
+        dist.init_parallel_env()
+        loss_disable = run_sharding_test(enable_inplace_master_grad=False)
+        loss_enable = run_sharding_test(enable_inplace_master_grad=True)
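+        # The loss after five steps must be identical with the flag disabled and enabled.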
+        self.check_tensor_eq(loss_disable, loss_enable)
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
@@ -188,6 +220,7 @@ def run_test_case(self):
         self.test_sharding_stage_1_to_static()
         self.test_pure_sharding_multi_mesh_stage_1()
         self.test_sharding_stage_1_overlap_to_static()
+        self.test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad()
 
 
 if __name__ == '__main__':