File tree 1 file changed +5
-4
lines changed
python/paddle/distributed/auto_parallel
1 file changed +5
-4
lines changed Original file line number Diff line number Diff line change 21
21
from ..collective import _new_ring_id
22
22
from ...fluid .framework import _non_static_mode
23
23
from ...fluid .layers .tensor import fill_constant
24
- from paddle . fluid . framework import _enable_legacy_dygraph
24
+ from paddle import _legacy_C_ops
25
25
26
26
27
27
def get_all_process_groups ():
@@ -145,14 +145,15 @@ def instantiate(self):
145
145
# TODO(shenliang03): This is a temporary solution to solve the problem of
146
146
# hang caused by cross-creation of new_group
147
147
paddle .disable_static ()
148
- _enable_legacy_dygraph ()
149
148
paddle .set_device ('gpu:%d' %
150
149
paddle .distributed .ParallelEnv ().dev_id )
151
150
tmp = paddle .to_tensor (
152
151
[1 ], dtype = "int32" ) if _non_static_mode () else fill_constant (
153
152
[0 ], dtype = "int32" , value = "1" )
154
- paddle .distributed .all_reduce (tmp , sync_op = True , group = self )
155
- paddle .distributed .wait (tmp , group = self )
153
+ # use legacy ops
154
+ _legacy_C_ops .c_allreduce_sum_ (tmp , 'use_calc_stream' , True ,
155
+ 'ring_id' , self .id )
156
+ _legacy_C_ops .c_sync_calc_stream (tmp , tmp )
156
157
paddle .enable_static ()
157
158
158
159
self ._is_instantiate = True
You can’t perform that action at this time.
0 commit comments