Commit 3139a0a

[Distributed] destroy && recreate ncclComm (#72626) (#72648)
1 parent 53eb4cd commit 3139a0a

13 files changed: +275 -9 lines

paddle/fluid/distributed/collective/process_group_nccl.cc (+31 -2)

@@ -20,7 +20,6 @@
 #include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"
 #include "paddle/phi/core/distributed/check/static_check.h"
-#include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/distributed/comm_task_manager.h"
 #include "paddle/phi/core/distributed/nccl_comm_task.h"
 #include "paddle/phi/core/distributed/nccl_tools.h"
@@ -146,7 +145,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       place_to_group_key_(),
       pg_timeout_(timeout),
       nccl_comm_init_option_(nccl_comm_init_option),
-      allocation_stream_pairs_() {
+      allocation_stream_pairs_(),
+      place_to_p2p_opts_(),
+      create_count_(0) {
   LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_;
   LOG(INFO) << "ProcessGroupNCCL nccl_comm_init_option_ "
             << nccl_comm_init_option_;
@@ -948,12 +949,40 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
       platform::DeviceEvent(place, platform::GenerateDeviceEventFlag()));
   place_to_calc_ctx_.emplace(place_key, calc_ctx);
   place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
+  place_to_p2p_opts_.emplace(place_key, std::move(p2p_opts));

   for (size_t i = 0; i < s_group_call_counter; ++i) {
     NCCL_CHECK(phi::dynload::ncclGroupStart());
   }
 }

+void ProcessGroupNCCL::Shutdown() {
+  for (size_t i = 0; i < s_group_call_counter; ++i) {
+    NCCL_CHECK(phi::dynload::ncclGroupEnd());
+  }
+
+  for (auto key_iter = place_to_group_key_.begin();
+       key_iter != place_to_group_key_.end();
+       ++key_iter) {
+    std::string store_key = key_iter->second;
+    auto nccl_comm_ctx = this->GetCommContext(&store_key);
+    nccl_comm_ctx->DestroyNCCLComm();
+  }
+}
+
+void ProcessGroupNCCL::Restart() {
+  for (auto key_iter = place_to_group_key_.begin();
+       key_iter != place_to_group_key_.end();
+       ++key_iter) {
+    std::string place_key = key_iter->first;
+    std::string store_key = key_iter->second;
+    phi::distributed::P2POption p2p_opts = place_to_p2p_opts_.at(place_key);
+    phi::distributed::CommContextManager::RecreateNCCLComm(
+        store_, store_key, rank_, std::to_string(create_count_), &p2p_opts);
+    create_count_++;
+  }
+}
+
 void ProcessGroupNCCL::SyncCalcStream(const Place& place,
                                       const std::string& place_key) {
   auto& calc_event = place_to_calc_event_.at(place_key);

paddle/fluid/distributed/collective/process_group_nccl.h (+8)

@@ -25,6 +25,7 @@
 #include "paddle/phi/backends/gpu/forwards.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 #include "paddle/phi/core/distributed/nccl_comm_context.h"
 #include "paddle/phi/core/distributed/store/store.h"
 #include "paddle/phi/core/platform/device_event.h"
@@ -190,6 +191,9 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
   phi::distributed::NCCLCommContext* GetOrCreateCommContext(
       const Place& place, CommType comm_type = CommType::UNKNOWN);

+  void Shutdown();
+  void Restart();
+
  private:
   std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(const Place& place,
                                                          int rank,
@@ -287,6 +291,10 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
   bool is_coalescing_{false};
   std::vector<std::shared_ptr<phi::DenseTensor>> coalescing_tensors_;
   std::vector<std::string> coalescing_place_keys_;
+
+  std::unordered_map<std::string, phi::distributed::P2POption>
+      place_to_p2p_opts_;
+  int64_t create_count_;
 };

 }  // namespace distributed

paddle/fluid/pybind/distributed_py.cc (+3 -1)

@@ -1247,7 +1247,9 @@ void BindDistributed(py::module *m) {
            py::arg("nccl_comm_init_option") = 0,
            py::call_guard<py::gil_scoped_release>())
       .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart)
-      .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd);
+      .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd)
+      .def("shutdown", &distributed::ProcessGroupNCCL::Shutdown)
+      .def("restart", &distributed::ProcessGroupNCCL::Restart);

   py::class_<distributed::AsyncLoad::Task,
              std::shared_ptr<distributed::AsyncLoad::Task>>(*m, "AsyncLoadTask")
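These bindings expose the new C++ Shutdown/Restart methods on the process-group object that Python already holds. A minimal sketch of driving them directly, assuming a two-rank job started with the usual distributed launcher (the shutdown_process_group/restart_process_group wrappers added in collective.py below are the intended public entry points):

import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group([0, 1])

# Group.process_group wraps the bound ProcessGroupNCCL: shutdown() destroys
# its cached ncclComm handles, restart() rebuilds them via RecreateNCCLComm.
group.process_group.shutdown()
group.process_group.restart()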

paddle/phi/core/distributed/comm_context_manager.cc (+32)

@@ -126,6 +126,38 @@ void CommContextManager::CreateNCCLCommContext(
   comm_context_manager.SetStore(store);
   comm_context_manager.Emplace(unique_comm_key, std::move(nccl_comm_context));
 }
+
+void CommContextManager::RecreateNCCLComm(const std::shared_ptr<Store>& store,
+                                          const std::string& unique_comm_key,
+                                          int rank,
+                                          const std::string& hash_key,
+                                          const P2POption* p2p_opt) {
+  auto& comm_context_manager = CommContextManager::GetInstance();
+
+  ncclUniqueId nccl_id;
+  if (rank == 0 || (p2p_opt && p2p_opt->is_p2p_op && p2p_opt->p2p_rank == 0)) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclGetUniqueId(&nccl_id));
+  }
+
+  std::string unique_key = "NCCLCommContext/" + unique_comm_key + hash_key;
+  if (rank == 0 || (p2p_opt && p2p_opt->is_p2p_op && p2p_opt->p2p_rank == 0)) {
+    std::vector<uint8_t> nccl_id_wrapper(
+        reinterpret_cast<uint8_t*>(&nccl_id),
+        reinterpret_cast<uint8_t*>(&nccl_id) + NCCL_UNIQUE_ID_BYTES);
+    store->set(unique_key, nccl_id_wrapper);
+  } else {
+    const auto& nccl_id_wrapper = store->get(unique_key);
+    std::memcpy(&nccl_id, nccl_id_wrapper.data(), nccl_id_wrapper.size());
+  }
+
+  VLOG(3) << "RecreateNCCLComm nccl_id: " << SerializeNCCLUniqueId(nccl_id);
+
+  auto comm_context = static_cast<phi::distributed::NCCLCommContext*>(
+      comm_context_manager.Get(unique_comm_key));
+  comm_context->CreateNCCLComm(nccl_id);
+
+  comm_context_manager.SetStore(store);
+}
 #endif

 #if defined(PADDLE_WITH_GLOO)
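RecreateNCCLComm follows the same store-based rendezvous as CreateNCCLCommContext: the root rank (rank 0, or the p2p root when p2p_opt says so) calls ncclGetUniqueId, publishes the id bytes in the store under a key that includes the per-recreate hash_key, and every other rank reads those bytes back before CreateNCCLComm is called on the existing context. A minimal Python sketch of that exchange, assuming a hypothetical blocking dict-like store shared by all ranks and using random bytes as a stand-in for ncclGetUniqueId:

import os

NCCL_UNIQUE_ID_BYTES = 128  # size of ncclUniqueId

def exchange_unique_id(store, rank, unique_comm_key, hash_key):
    # hash_key changes on every recreate, so a stale id is never reused.
    unique_key = "NCCLCommContext/" + unique_comm_key + hash_key
    if rank == 0:
        nccl_id = os.urandom(NCCL_UNIQUE_ID_BYTES)  # stand-in for ncclGetUniqueId()
        store[unique_key] = nccl_id                 # root publishes the id
        return nccl_id
    return store[unique_key]                        # other ranks read the same bytes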

paddle/phi/core/distributed/comm_context_manager.h (+5)

@@ -83,6 +83,11 @@ class CommContextManager {
       const std::string& hash_key = "",
       const P2POption* opt = nullptr,
       int nccl_comm_init_option = 0);
+  static void RecreateNCCLComm(const std::shared_ptr<Store>& store,
+                               const std::string& unique_comm_key,
+                               int rank,
+                               const std::string& hash_key = "",
+                               const P2POption* opt = nullptr);
 #endif

 #if defined(PADDLE_WITH_GLOO)

paddle/phi/core/distributed/nccl_comm_context.cc (+22 -6)

@@ -33,19 +33,35 @@ NCCLCommContext::NCCLCommContext(int rank,
                                  int size,
                                  ncclUniqueId nccl_id,
                                  int nccl_comm_init_option)
-    : CommContext(rank, size), nccl_version_(0), nccl_comm_(nullptr) {
-  if (nccl_comm_init_option > 0 && phi::dynload::ncclCommInitRank2.IsValid()) {
+    : CommContext(rank, size),
+      nccl_version_(0),
+      nccl_comm_(nullptr),
+      nranks(size_),
+      myrank(rank_),
+      param(nccl_comm_init_option) {
+  this->CreateNCCLComm(nccl_id);
+  NCCL_CHECK(phi::dynload::ncclGetVersion(&nccl_version_));
+}
+
+void NCCLCommContext::CreateNCCLComm(ncclUniqueId nccl_id) {
+  if (param > 0 && phi::dynload::ncclCommInitRank2.IsValid()) {
     LOG(WARNING) << "Creating modified qp with ncclCommInitRank2.";
     NCCL_CHECK(phi::dynload::ncclCommInitRank2(
-        &nccl_comm_, size_, nccl_id, rank_, nccl_comm_init_option));
+        &nccl_comm_, nranks, nccl_id, myrank, param));
   } else {
-    if (nccl_comm_init_option > 0) {
+    if (param > 0) {
       LOG(WARNING) << "ncclCommInitRank2 is not supported.";
     }
     NCCL_CHECK(
-        phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_));
+        phi::dynload::ncclCommInitRank(&nccl_comm_, nranks, nccl_id, myrank));
+  }
+}
+
+void NCCLCommContext::DestroyNCCLComm() {
+  if (nccl_comm_ != nullptr) {
+    NCCL_CHECK(phi::dynload::ncclCommDestroy(nccl_comm_));
+    nccl_comm_ = nullptr;
   }
-  NCCL_CHECK(phi::dynload::ncclGetVersion(&nccl_version_));
 }

 int NCCLCommContext::GetNcclVersion() { return nccl_version_; }

paddle/phi/core/distributed/nccl_comm_context.h (+8)

@@ -49,6 +49,10 @@ class NCCLCommContext final : public CommContext {

   ncclComm_t GetNcclComm();

+  void CreateNCCLComm(ncclUniqueId nccl_id);
+
+  void DestroyNCCLComm();
+
   gpuStream_t GetStream();

   gpuEvent_t GetComputeEvent();
@@ -132,6 +136,10 @@ class NCCLCommContext final : public CommContext {

   // used for compute wait comm, comm_stream-->event-->compute_stream
   std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> comm_event_;
+
+  int nranks;
+  int myrank;
+  int param;
 };

 }  // namespace distributed

python/paddle/distributed/__init__.py (+4)

@@ -71,6 +71,8 @@
 from .collective import (
     is_available,
     new_group,
+    restart_process_group,
+    shutdown_process_group,
     split,
 )
 from .communication import (  # noqa: F401
@@ -137,6 +139,8 @@
     "broadcast_object_list",
     "ParallelEnv",
     "new_group",
+    "shutdown_process_group",
+    "restart_process_group",
     "init_parallel_env",
     "gloo_init_parallel_env",
     "gloo_barrier",

python/paddle/distributed/collective.py (+58)

@@ -388,3 +388,61 @@ def _init_parallel_env(backend: _BackendList) -> None:
         core.CommContextManager.create_bkcl_comm_context(
             store, "0", rank, world_size, endpoints_str_hash
         )
+
+
+_shutdown_group_map_by_name = {}
+
+
+def _get_shutdown_group_map_by_name():
+    global _shutdown_group_map_by_name
+    return _shutdown_group_map_by_name
+
+
+def _update_shutdown_group_map_by_name(pg_name, group):
+    global _shutdown_group_map_by_name
+    _shutdown_group_map_by_name[pg_name] = group
+
+
+def _delete_shutdown_group_map_by_name(pg_name):
+    global _shutdown_group_map_by_name
+    del _shutdown_group_map_by_name[pg_name]
+
+
+def _clear_shutdown_group_map_by_name():
+    global _shutdown_group_map_by_name
+    _shutdown_group_map_by_name.clear()
+
+
+def shutdown_process_group(group: Group | None = None) -> None:
+    shutdown_groups = _get_shutdown_group_map_by_name()
+
+    if group is None:
+        global _default_group_name
+        for pg_name, pg in _get_group_map_by_name().items():
+            if (
+                pg.process_group is not None
+                and pg_name not in shutdown_groups
+                and pg_name != _default_group_name
+            ):
+                pg.process_group.shutdown()
+                _update_shutdown_group_map_by_name(pg_name, pg)
+    else:
+        if (
+            group.process_group is not None
+            and group.name not in shutdown_groups
+        ):
+            group.process_group.shutdown()
+            _update_shutdown_group_map_by_name(group.name, group)
+
+
+def restart_process_group(group: Group | None = None) -> None:
+    shutdown_groups = _get_shutdown_group_map_by_name()
+
+    if group is None:
+        for pg in shutdown_groups.values():
+            pg.process_group.restart()
+        _clear_shutdown_group_map_by_name()
+    else:
+        if group.process_group is not None and group.name in shutdown_groups:
+            group.process_group.restart()
+            _delete_shutdown_group_map_by_name(group.name)
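The _shutdown_group_map_by_name bookkeeping makes both calls idempotent per group: a group is shut down at most once, and only previously shut-down groups are restarted. A short usage sketch of the public API, assuming a two-rank job (the new unit test further down exercises the same flow):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
pg = dist.new_group([0, 1])

x = paddle.to_tensor([float(dist.get_rank())])
dist.all_reduce(x, group=pg)        # uses the original ncclComm

dist.shutdown_process_group(pg)     # destroy this group's communicator
dist.restart_process_group(pg)      # recreate it from a fresh ncclUniqueId

dist.all_reduce(x, group=pg)        # runs again on the recreated ncclComm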

test/collective/fleet/CMakeLists.txt (+14)

@@ -822,3 +822,17 @@ if(LOCAL_ALL_ARCH AND (LINUX OR APPLE))
     "PADDLE_DIST_UT_PORT=21212;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
   )
 endif()
+if((WITH_GPU) AND LOCAL_ALL_PLAT)
+  bash_test_modules(
+    test_shutdown_process_group
+    START_BASH
+    ../../legacy_test/dist_test.sh
+    TIMEOUT
+    "200"
+    LABELS
+    "RUN_TYPE=DIST"
+    ENVS
+    "PADDLE_DIST_UT_PORT=22024;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
+  )
+  set_tests_properties(test_shutdown_process_group PROPERTIES TIMEOUT "200")
+endif()
New file (+61)

@@ -0,0 +1,61 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+
+
+class TestShutdownProcessGroupAPI:
+    def __init__(self):
+        dist.init_parallel_env()
+        if dist.get_rank() == 0:
+            self.data = paddle.to_tensor([[7, 8, 9], [10, 11, 12]])
+        else:
+            self.data = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
+
+    def test_shutdown_and_recreate_all(self):
+        pg = paddle.distributed.new_group([0, 1])
+
+        result_base = self.data.clone()
+        dist.all_reduce(result_base, group=pg)
+
+        paddle.distributed.shutdown_process_group()
+        paddle.distributed.restart_process_group()
+
+        result_test = self.data.clone()
+        dist.all_reduce(result_test, group=pg)
+
+        np.testing.assert_array_equal(result_base.numpy(), result_test.numpy())
+
+    def test_shutdown_and_recreate_single(self):
+        pg = paddle.distributed.new_group([0, 1])
+
+        result_base = self.data.clone()
+        dist.all_reduce(result_base, group=pg)
+
+        paddle.distributed.shutdown_process_group(pg)
+        paddle.distributed.restart_process_group(pg)
+
+        result_test = self.data.clone()
+        dist.all_reduce(result_test, group=pg)
+
+        np.testing.assert_array_equal(result_base.numpy(), result_test.numpy())
+
+
+if __name__ == "__main__":
+    test_case = TestShutdownProcessGroupAPI()
+    test_case.test_shutdown_and_recreate_all()
+    test_case.test_shutdown_and_recreate_single()
