@@ -125,7 +125,9 @@ def run_trainer(self, args):
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
         nranks = 2
-        if args["use_comm_context"] or args["dynamic_static_unified_comm"]:
+        if args['static_mode'] and (
+            args["use_comm_context"] or args["dynamic_static_unified_comm"]
+        ):
             paddle.distributed.collective._init_parallel_env(args["backend"])
         else:
             paddle.distributed.init_parallel_env()
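The effect of this hunk: the comm-context initializer is now reached only in static-graph mode. A minimal standalone sketch of the new gating, assuming a plain dict of args (the helper name is hypothetical, not Paddle API):

```python
# Hypothetical helper mirroring the gating in the hunk above; not Paddle code.
def should_init_with_comm_context(args: dict) -> bool:
    # The two existing flags only take effect when static_mode is also set.
    return bool(args.get("static_mode")) and bool(
        args.get("use_comm_context") or args.get("dynamic_static_unified_comm")
    )

# A dynamic-mode run now falls through to init_parallel_env() even when
# dynamic_static_unified_comm is enabled.
assert not should_init_with_comm_context(
    {"static_mode": False, "dynamic_static_unified_comm": True}
)
assert should_init_with_comm_context({"static_mode": True, "use_comm_context": True})
```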
@@ -153,11 +155,7 @@ def run_trainer(self, args):
             )
             if args["use_comm_context"]
             else (
-                self.get_model_new_comm(
-                    train_prog, startup_prog, rank, dtype=args['dtype']
-                )
-                if args["dynamic_static_unified_comm"]
-                else self.get_model(
+                self.get_model(
                     train_prog, startup_prog, rank, dtype=args['dtype']
                 )
             )
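This hunk collapses a three-way conditional expression into two branches by dropping the separate `get_model_new_comm` path. A hedged sketch of the resulting branching shape (the stub class and return values are illustrative; only the control flow mirrors the diff):

```python
# Stub illustrating the simplified model selection; not Paddle code.
class TrainerStub:
    def get_model_comm_context(self):
        return "comm-context model"

    def get_model(self):
        return "default model"

    def pick(self, use_comm_context: bool):
        # Previously there was a nested "if dynamic_static_unified_comm"
        # arm calling get_model_new_comm; now the else branch goes
        # straight to get_model.
        return (
            self.get_model_comm_context()
            if use_comm_context
            else self.get_model()
        )

assert TrainerStub().pick(use_comm_context=False) == "default model"
```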
@@ -190,8 +188,7 @@ def runtime_main(test_class, col_type):
     args["reduce_type"] = os.getenv("REDUCE_TYPE")
     args["use_comm_context"] = bool(int(os.getenv("USE_COMM_CONTEXT", "0")))
     args["dynamic_static_unified_comm"] = bool(
-        os.getenv("FLAGS_dynamic_static_unified_comm", "false").lower()
-        == "true"
+        os.getenv("FLAGS_dynamic_static_unified_comm", "true").lower() == "true"
     )
     model.run_trainer(args)
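Worth noting on this hunk: the default flips from "false" to "true", making the unified-comm flag opt-out rather than opt-in. A small standalone sketch of the parsing pattern (the outer `bool(...)` in the diff is redundant, since `==` already yields a bool):

```python
import os

# With the variable unset, the flag now evaluates to True by default.
os.environ.pop("FLAGS_dynamic_static_unified_comm", None)
enabled = os.getenv("FLAGS_dynamic_static_unified_comm", "true").lower() == "true"
assert enabled

# Exporting FLAGS_dynamic_static_unified_comm=0 still disables it,
# because "0" != "true".
os.environ["FLAGS_dynamic_static_unified_comm"] = "0"
enabled = os.getenv("FLAGS_dynamic_static_unified_comm", "true").lower() == "true"
assert not enabled
```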
@@ -352,7 +349,6 @@ def check_with_place(
             "PATH_ID": path_id,
             "DTYPE": dtype,
             "REDUCE_TYPE": str(reduce_type),
-            "FLAGS_dynamic_static_unified_comm": "0",
         }
         required_envs.update(additional_envs)
         required_envs.update(need_envs)
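With the pinned "0" removed here, spawned trainers inherit the new "true" default unless a test overrides it. A hedged sketch of the merge order, with illustrative values (only the keys and the `update` sequence mirror the diff):

```python
# Illustrative values only; the update order matches the diff above.
required_envs = {
    "PATH_ID": "0",
    "DTYPE": "float32",
    "REDUCE_TYPE": "sum",
    # "FLAGS_dynamic_static_unified_comm": "0" is no longer forced here.
}
additional_envs = {}
need_envs = {"FLAGS_dynamic_static_unified_comm": "0"}  # a test can still opt out

required_envs.update(additional_envs)
required_envs.update(need_envs)  # need_envs wins, applied last
assert required_envs["FLAGS_dynamic_static_unified_comm"] == "0"
```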