@@ -30,6 +30,7 @@ COMM_OP_TYPE = [
     "broadcast",
     "all_gather",
     "c_allreduce_sum",
+    "all_reduce",
     "c_identity",
 ]
 NON_COMP_TYPE = ["while", *COMM_OP_TYPE]
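Adding `"all_reduce"` to `COMM_OP_TYPE` makes the new-style op classified as a communication op (and, via `NON_COMP_TYPE`, excluded from computation costing). A toy sketch of that membership check, assuming the lists are used as shown in the hunk (the helper `is_comm_op` is invented for illustration, and only the list entries visible above are reproduced):

```python
# Illustration only: membership in COMM_OP_TYPE / NON_COMP_TYPE is what
# decides how an op is costed. is_comm_op is not part of the diff.
COMM_OP_TYPE = ["broadcast", "all_gather", "c_allreduce_sum", "all_reduce", "c_identity"]
NON_COMP_TYPE = ["while", *COMM_OP_TYPE]

def is_comm_op(op_type: str) -> bool:
    return op_type in COMM_OP_TYPE

assert is_comm_op("all_reduce")       # recognized after this change
assert "all_reduce" in NON_COMP_TYPE  # hence skipped when costing computation
```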
@@ -311,7 +312,10 @@ def build_comm_desc_from_dist_op(
             input_list.append((var.dtype, shape))

         # NOTE: The input_name of comm ops used usually is X.
-        desc["inputs"] = {"X": input_list}
+        if op_type == "all_reduce":
+            desc["inputs"] = {"x": input_list}
+        else:
+            desc["inputs"] = {"X": input_list}

         # Get comm group by parallel_axis or the given group_ranks.
         if parallel_axis is not None:
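The only behavioral change in `build_comm_desc_from_dist_op` is the input key: the new `all_reduce` op names its input slot `x`, while the legacy comm ops keep `X`. A hedged sketch of the branching (the helper `pack_inputs` and the dtype/shape values are invented; `input_list` entries are the `(dtype, shape)` tuples appended above):

```python
# Illustration of the key selection added in this hunk; not code from the PR.
def pack_inputs(op_type, input_list):
    key = "x" if op_type == "all_reduce" else "X"
    return {key: input_list}

print(pack_inputs("all_reduce", [("float32", [8, 1024])]))       # {'x': [...]}
print(pack_inputs("c_allreduce_sum", [("float32", [8, 1024])]))  # {'X': [...]}
```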
@@ -349,7 +353,10 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
     desc = {}
     desc["op"] = op_type
     desc["group_ranks"] = group_ranks
-    desc["inputs"] = {"X": [(dtype, shape)]}
+    if op_type == "all_reduce":
+        desc["inputs"] = {"x": [(dtype, shape)]}
+    else:
+        desc["inputs"] = {"X": [(dtype, shape)]}
     desc["attrs"] = attrs
     return desc

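`build_comm_desc` applies the same key selection. A usage sketch, assuming the function is in scope (in Paddle it lives in the auto-parallel cost utilities; the group ranks, dtype, and shape are made-up values, and the expected dicts follow directly from the body shown above):

```python
# Expected descriptors after this change (values illustrative).
desc = build_comm_desc("all_reduce", [0, 1, 2, 3], "float32", [1024])
# {'op': 'all_reduce', 'group_ranks': [0, 1, 2, 3],
#  'inputs': {'x': [('float32', [1024])]}, 'attrs': None}

desc = build_comm_desc("c_allreduce_sum", [0, 1, 2, 3], "float32", [1024])
# {'op': 'c_allreduce_sum', 'group_ranks': [0, 1, 2, 3],
#  'inputs': {'X': [('float32', [1024])]}, 'attrs': None}
```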
@@ -416,7 +423,7 @@ def build_dp_costs(
     if not has_found:
         return

-    c_allreduce_sum_descs = build_comm_desc_from_dist_op(
+    all_reduce_sum_descs = build_comm_desc_from_dist_op(
         "c_allreduce_sum",
         dist_op,
         ctx,
@@ -428,7 +435,7 @@ def build_dp_costs(
         _g_op_cost_factory["c_allreduce_sum"],
         ctx,
         processes,
-        c_allreduce_sum_descs,
+        all_reduce_sum_descs,
         cluster,
         is_dp=True,
     )
@@ -787,17 +794,27 @@ def comm_count(self):
                 vars = self.op.block.vars
                 # NOTE: The tensor communicated input_name is "X" in default. Otherwise, this function should be overridden
                 try:
-                    var_name = self.op.input("X")[0]
+                    if self.op.type != "all_reduce":
+                        var_name = self.op.input("X")[0]
+                    else:
+                        var_name = self.op.input("x")[0]
                 except:
-                    var_name = self.op.output("Out")[0]
+                    if self.op.type != "all_reduce":
+                        var_name = self.op.output("Out")[0]
+                    else:
+                        var_name = self.op.output("out")[0]
                 var = get_var_with_recursion(
                     var_name, self.op.block, self.op.block.program
                 )
                 dtype = var.dtype
                 shape = var.shape
             elif self.op_desc is not None:
-                dtype = self.op_desc["inputs"]["X"][0][0]
-                shape = self.op_desc["inputs"]["X"][0][1]
+                if "op" in self.op_desc and self.op_desc["op"] == "all_reduce":
+                    dtype = self.op_desc["inputs"]["x"][0][0]
+                    shape = self.op_desc["inputs"]["x"][0][1]
+                else:
+                    dtype = self.op_desc["inputs"]["X"][0][0]
+                    shape = self.op_desc["inputs"]["X"][0][1]

             factor = None
             if dtype == paddle.float32 or dtype == paddle.int32:
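In `comm_count`, the op-desc branch now picks the input key by op type before reading the `(dtype, shape)` tuple. A hedged, standalone sketch of that lookup against a descriptor like the ones produced by `build_comm_desc` (the values are invented, and only the float32/int32 byte factor visible above is reproduced):

```python
# Illustration only; not the Paddle implementation.
from math import prod

op_desc = {"op": "all_reduce", "inputs": {"x": [("float32", (8, 1024))]}}
key = "x" if op_desc.get("op") == "all_reduce" else "X"
dtype, shape = op_desc["inputs"][key][0]

# float32/int32 occupy 4 bytes per element, matching the branch shown above;
# how Paddle folds this factor into comm_count is not shown in this hunk.
bytes_per_element = 4 if dtype in ("float32", "int32") else None
print(dtype, shape, prod(shape) * bytes_per_element)  # float32 (8, 1024) 32768
```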