|
53 | 53 | FUSED_LINEAR_SOURCE_PATTERNS_LIST = [
|
54 | 54 | # amp_level == 'o2' or 'o3'
|
55 | 55 | { # only MP
|
56 |
| - "forward": ["matmul_v2", "all_reduce", "elementwise_add"], |
| 56 | + "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], |
57 | 57 | "backward": ["elementwise_add_grad", "matmul_v2_grad"],
|
58 | 58 | },
|
59 | 59 | { # MP + SP
|
60 | 60 | "forward": ["matmul_v2", "reduce_scatter", "elementwise_add"],
|
61 | 61 | "backward": [
|
62 | 62 | "elementwise_add_grad",
|
63 |
| - "c_allreduce_sum", |
| 63 | + "all_reduce", |
64 | 64 | "scale",
|
65 | 65 | "all_gather",
|
66 | 66 | "matmul_v2_grad",
|
67 | 67 | "all_gather",
|
68 | 68 | ],
|
69 | 69 | },
|
70 | 70 | { # DP + MP
|
71 |
| - "forward": ["matmul_v2", "all_reduce", "elementwise_add"], |
| 71 | + "forward": ["matmul_v2", "c_allreduce_sum", "elementwise_add"], |
72 | 72 | "backward": [
|
73 | 73 | "elementwise_add_grad",
|
74 |
| - "c_allreduce_sum", |
| 74 | + "all_reduce", |
75 | 75 | "scale",
|
76 | 76 | "matmul_v2_grad",
|
77 | 77 | ],
|
|
80 | 80 | "forward": ["matmul_v2", "reduce_scatter", "elementwise_add"],
|
81 | 81 | "backward": [
|
82 | 82 | "elementwise_add_grad",
|
83 |
| - "c_allreduce_sum", |
| 83 | + "all_reduce", |
84 | 84 | "scale",
|
85 |
| - "c_allreduce_sum", |
| 85 | + "all_reduce", |
86 | 86 | "scale",
|
87 | 87 | "all_gather",
|
88 | 88 | "matmul_v2_grad",
|
|
91 | 91 | },
|
92 | 92 | # amp_level == 'o1'
|
93 | 93 | {
|
94 |
| - "forward": ["matmul_v2", "all_reduce", "cast", "elementwise_add"], |
| 94 | + "forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"], |
95 | 95 | "backward": ["elementwise_add_grad", "matmul_v2_grad"],
|
96 | 96 | },
|
97 | 97 | {
|
98 | 98 | "forward": ["matmul_v2", "reduce_scatter", "cast", "elementwise_add"],
|
99 | 99 | "backward": [
|
100 | 100 | "elementwise_add_grad",
|
101 |
| - "c_allreduce_sum", |
| 101 | + "all_reduce", |
102 | 102 | "scale",
|
103 | 103 | "all_gather",
|
104 | 104 | "all_gather",
|
105 | 105 | "matmul_v2_grad",
|
106 | 106 | ],
|
107 | 107 | },
|
108 | 108 | {
|
109 |
| - "forward": ["matmul_v2", "all_reduce", "cast", "elementwise_add"], |
| 109 | + "forward": ["matmul_v2", "c_allreduce_sum", "cast", "elementwise_add"], |
110 | 110 | "backward": [
|
111 | 111 | "elementwise_add_grad",
|
112 |
| - "c_allreduce_sum", |
| 112 | + "all_reduce", |
113 | 113 | "scale",
|
114 | 114 | "matmul_v2_grad",
|
115 | 115 | ],
|
|
118 | 118 | "forward": ["matmul_v2", "reduce_scatter", "cast", "elementwise_add"],
|
119 | 119 | "backward": [
|
120 | 120 | "elementwise_add_grad",
|
121 |
| - "c_allreduce_sum", |
| 121 | + "all_reduce", |
122 | 122 | "scale",
|
123 |
| - "c_allreduce_sum", |
| 123 | + "all_reduce", |
124 | 124 | "scale",
|
125 | 125 | "all_gather",
|
126 | 126 | "matmul_v2_grad",
|
|
0 commit comments