@@ -405,13 +405,13 @@ def _shard_gradient_clip(self, main_block):
405
405
for i , sharding_info in enumerate (self .sharding_infos ):
406
406
new_op = main_block ._insert_op (
407
407
idx + i + 1 ,
408
- type = 'c_allreduce_sum ' ,
409
- inputs = {'X ' : [sum_op_output ]},
410
- outputs = {'Out ' : [sum_op_output ]},
408
+ type = 'all_reduce ' ,
409
+ inputs = {'x ' : [sum_op_output ]},
410
+ outputs = {'out ' : [sum_op_output ]},
411
411
attrs = {
412
412
'ring_id' : sharding_info .group .id ,
413
413
'op_namescope' : "/gradient_clip_model_parallelism" ,
414
- 'use_calc_stream ' : True ,
414
+ 'reduce_type ' : paddle . distributed . ReduceOp . SUM ,
415
415
OP_ROLE_KEY : OpRole .Optimize ,
416
416
},
417
417
)
@@ -535,9 +535,16 @@ def _shard_gradient_synchronization(self, main_block):
535
535
dp_ring_ids = [group .id for group in self .dp_groups ]
536
536
for idx , op in reversed (list (enumerate (main_block .ops ))):
537
537
if _is_param_grad_allreduce_op (op , main_block ):
538
- if op .type == "c_allreduce_sum" or (
539
- op .type == "reduce"
540
- and op .attr ("reduce_type" ) == dist .ReduceOp .SUM
538
+ if (
539
+ op .type == "c_allreduce_sum"
540
+ or (
541
+ op .type == "all_reduce"
542
+ and op .attr ("reduce_type" ) == dist .ReduceOp .SUM
543
+ )
544
+ or (
545
+ op .type == "reduce"
546
+ and op .attr ("reduce_type" ) == dist .ReduceOp .SUM
547
+ )
541
548
):
542
549
reduce_op_type = "reduce"
543
550
reduce_type = dist .ReduceOp .SUM
@@ -1036,7 +1043,13 @@ def op_depend_on_group(op, group):
1036
1043
cur_group .is_in_local_shard = True
1037
1044
assert ops [i + 1 ].type in [
1038
1045
"c_allreduce_sum" ,
1039
- ], "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
1046
+ ] or (
1047
+ ops [i + 1 ].type == 'all_reduce'
1048
+ and ops [i + 1 ].attr ('reduce_type' )
1049
+ in [
1050
+ paddle .distributed .ReduceOp .SUM ,
1051
+ ]
1052
+ ), "Sharding should reduce grad first and than allreduce if Hybrid Sharding with Data-Parallel"
1040
1053
assert (
1041
1054
ops [i + 1 ].output_arg_names [0 ] == grad_name
1042
1055
), "Hybrid Sharding with Data-Parallel should sync same gradient var"
@@ -1236,7 +1249,13 @@ def _overlap_grad_comm(
1236
1249
grad_comm_op_to_stream_idx = {}
1237
1250
for idx , op in enumerate (ops ):
1238
1251
if is_data_parallel_reduce_op (op ):
1239
- if op .type in ["c_allreduce_sum" ]:
1252
+ if op .type in ["c_allreduce_sum" ] or (
1253
+ op .type == 'all_reduce'
1254
+ and op .attr ('reduce_type' )
1255
+ in [
1256
+ paddle .distributed .ReduceOp .SUM ,
1257
+ ]
1258
+ ):
1240
1259
continue
1241
1260
stream_idx = reduce_op_count % self .grad_comm_stream_num
1242
1261
grad_comm_op_to_stream_idx [op ] = stream_idx
@@ -1291,7 +1310,13 @@ def _overlap_grad_comm(
1291
1310
next_op = ops [idx + 1 ]
1292
1311
assert next_op .type in [
1293
1312
"c_allreduce_sum" ,
1294
- ]
1313
+ ] or (
1314
+ next_op .type == 'all_reduce'
1315
+ and next_op .attr ('reduce_type' )
1316
+ in [
1317
+ paddle .distributed .ReduceOp .SUM ,
1318
+ ]
1319
+ )
1295
1320
assert next_op .output ("Out" )[0 ] == reduce_varname
1296
1321
# FIXME hybrid sharding-dp support multi comm & stream in feature
1297
1322
# next_op._set_attr("ring_id", comm_group.id)
0 commit comments