Skip to content

Commit 3b5064d

Browse files

Commit message: update instantiate for auto parallel (#46883)
1 parent 0773639 commit 3b5064d

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

python/paddle/distributed/auto_parallel/process_group.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..collective import _new_ring_id
2222
from ...fluid.framework import _non_static_mode
2323
from ...fluid.layers.tensor import fill_constant
24-
from paddle.fluid.framework import _enable_legacy_dygraph
24+
from paddle import _legacy_C_ops
2525

2626

2727
def get_all_process_groups():
@@ -145,14 +145,15 @@ def instantiate(self):
145145
# TODO(shenliang03): This is a temporary solution to solve the problem of
146146
# hang caused by cross-creation of new_group
147147
paddle.disable_static()
148-
_enable_legacy_dygraph()
149148
paddle.set_device('gpu:%d' %
150149
paddle.distributed.ParallelEnv().dev_id)
151150
tmp = paddle.to_tensor(
152151
[1], dtype="int32") if _non_static_mode() else fill_constant(
153152
[0], dtype="int32", value="1")
154-
paddle.distributed.all_reduce(tmp, sync_op=True, group=self)
155-
paddle.distributed.wait(tmp, group=self)
153+
# use legacy ops
154+
_legacy_C_ops.c_allreduce_sum_(tmp, 'use_calc_stream', True,
155+
'ring_id', self.id)
156+
_legacy_C_ops.c_sync_calc_stream(tmp, tmp)
156157
paddle.enable_static()
157158

158159
self._is_instantiate = True

0 commit comments

Comments (0)