@@ -234,6 +234,34 @@ def run_sharding_test(enable_tensor_fusion):
         loss_enable = run_sharding_test(enable_tensor_fusion=True)
         self.check_tensor_eq(loss_disable, loss_enable)
 
+    def test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion_with_chip(
+        self,
+    ):
+        dist.init_parallel_env()
+        os.environ['FLAGS_enable_inplace_master_grad'] = '1'
+        os.environ['FLAGS_enable_tensor_fusion'] = '1'
+        paddle.distributed.auto_parallel.set_mesh(self._multi_dim_mesh)
+        paddle.seed(self._seed)
+        model = paddle.nn.Linear(10, 10)
+        batch = paddle.rand(shape=[10, 10])
+        batch = dist.shard_tensor(batch, self._mesh, [dist.Shard(0)])
+        clip = paddle.nn.ClipGradByGlobalNorm(1.0)
+        opt = paddle.optimizer.AdamW(
+            parameters=model.parameters(), grad_clip=clip
+        )
+        opt = dist.shard_optimizer(
+            opt, dist.ShardingStage1(sharding_mesh_dim="dp")
+        )
+        model, opt = paddle.amp.decorate(
+            model, optimizers=opt, level='O2', master_grad=True
+        )
+        for _ in range(5):
+            with paddle.amp.auto_cast(level='O2'):
+                loss = model(batch)
+            loss.backward()
+            opt.step()
+            opt.clear_grad()
+
     def run_test_case(self):
         if self._backend == "cpu":
             paddle.set_device("cpu")
@@ -251,6 +279,7 @@ def run_test_case(self):
         self.test_sharding_stage_1_overlap_to_static()
         self.test_pure_sharding_multi_mesh_stage_1_with_inplace_master_grad()
         self.test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion()
+        self.test_pure_sharding_multi_mesh_stage_1_with_tensor_fusion_with_chip()
 
 
 if __name__ == '__main__':
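
The flow the new test exercises can be distilled into a standalone sketch, shown below. This is a minimal, hypothetical reproduction, not the test itself: the two-card `["dp"]` mesh, the fixed seed, and the launch command are assumptions standing in for the test-class fixtures (`self._multi_dim_mesh`, `self._mesh`, `self._seed`), while the `FLAGS_*` switches and the `shard_optimizer` / `ShardingStage1` / AMP-O2 sequence are taken directly from the diff above.

```python
# Minimal sketch of the tensor-fusion + ShardingStage1 path covered by the new test.
# Assumption: two devices launched with paddle.distributed.launch; a 1-D "dp" mesh
# replaces the test fixtures self._multi_dim_mesh / self._mesh.
import os

import paddle
import paddle.distributed as dist

os.environ['FLAGS_enable_inplace_master_grad'] = '1'  # flag set by the new test
os.environ['FLAGS_enable_tensor_fusion'] = '1'        # flag set by the new test

dist.init_parallel_env()
mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])  # assumed 2-card mesh
paddle.distributed.auto_parallel.set_mesh(mesh)
paddle.seed(2024)  # any fixed seed; the test uses self._seed

model = paddle.nn.Linear(10, 10)
batch = dist.shard_tensor(paddle.rand(shape=[10, 10]), mesh, [dist.Shard(0)])

clip = paddle.nn.ClipGradByGlobalNorm(1.0)
opt = paddle.optimizer.AdamW(parameters=model.parameters(), grad_clip=clip)
# Stage-1 sharding of optimizer states over the "dp" mesh dimension, as in the test.
opt = dist.shard_optimizer(opt, dist.ShardingStage1(sharding_mesh_dim="dp"))
# AMP O2 with master_grad so the fused gradients are kept in fp32 master copies.
model, opt = paddle.amp.decorate(
    model, optimizers=opt, level='O2', master_grad=True
)

for _ in range(5):
    with paddle.amp.auto_cast(level='O2'):
        loss = model(batch)
    loss.backward()
    opt.step()
    opt.clear_grad()
```

A script like this would typically be run under the distributed launcher, e.g. `python -m paddle.distributed.launch --devices 0,1 sketch.py`, so that both ranks join the "dp" mesh.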