Skip to content

Commit f9b53f4

Browse files
committed
Fix
1 parent 3437370 commit f9b53f4

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

python/paddle/distributed/passes/auto_parallel_fused_linear_promotion.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,25 +53,25 @@
5353
FUSED_LINEAR_SOURCE_PATTERNS_LIST = [
5454
# amp_level == 'o2' or 'o3'
5555
{ # only MP
56-
"forward": ["matmul_v2", "all_reduce", "elementwise_add"],
56+
"forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"],
5757
"backward": ["elementwise_add_grad", "matmul_v2_grad"],
5858
},
5959
{ # MP + SP
6060
"forward": ["matmul_v2", "reduce_scatter", "elementwise_add"],
6161
"backward": [
6262
"elementwise_add_grad",
63-
"c_allreduce_sum",
63+
"all_reduce",
6464
"scale",
6565
"all_gather",
6666
"matmul_v2_grad",
6767
"all_gather",
6868
],
6969
},
7070
{ # DP + MP
71-
"forward": ["matmul_v2", "all_reduce", "elementwise_add"],
71+
"forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"],
7272
"backward": [
7373
"elementwise_add_grad",
74-
"c_allreduce_sum",
74+
"all_reduce",
7575
"scale",
7676
"matmul_v2_grad",
7777
],
@@ -80,9 +80,9 @@
8080
"forward": ["matmul_v2", "reduce_scatter", "elementwise_add"],
8181
"backward": [
8282
"elementwise_add_grad",
83-
"c_allreduce_sum",
83+
"all_reduce",
8484
"scale",
85-
"c_allreduce_sum",
85+
"all_reduce",
8686
"scale",
8787
"all_gather",
8888
"matmul_v2_grad",
@@ -91,25 +91,25 @@
9191
},
9292
# amp_level == 'o1'
9393
{
94-
"forward": ["matmul_v2", "all_reduce", "cast", "elementwise_add"],
94+
"forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"],
9595
"backward": ["elementwise_add_grad", "matmul_v2_grad"],
9696
},
9797
{
9898
"forward": ["matmul_v2", "reduce_scatter", "cast", "elementwise_add"],
9999
"backward": [
100100
"elementwise_add_grad",
101-
"c_allreduce_sum",
101+
"all_reduce",
102102
"scale",
103103
"all_gather",
104104
"all_gather",
105105
"matmul_v2_grad",
106106
],
107107
},
108108
{
109-
"forward": ["matmul_v2", "all_reduce", "cast", "elementwise_add"],
109+
"forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"],
110110
"backward": [
111111
"elementwise_add_grad",
112-
"c_allreduce_sum",
112+
"all_reduce",
113113
"scale",
114114
"matmul_v2_grad",
115115
],
@@ -118,9 +118,9 @@
118118
"forward": ["matmul_v2", "reduce_scatter", "cast", "elementwise_add"],
119119
"backward": [
120120
"elementwise_add_grad",
121-
"c_allreduce_sum",
121+
"all_reduce",
122122
"scale",
123-
"c_allreduce_sum",
123+
"all_reduce",
124124
"scale",
125125
"all_gather",
126126
"matmul_v2_grad",

0 commit comments

Comments
 (0)