|
20 | 20 | #include "paddle/phi/common/memory_utils.h"
|
21 | 21 | #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h"
|
22 | 22 | #include "paddle/phi/core/distributed/check/static_check.h"
|
23 |
| -#include "paddle/phi/core/distributed/comm_context_manager.h" |
24 | 23 | #include "paddle/phi/core/distributed/comm_task_manager.h"
|
25 | 24 | #include "paddle/phi/core/distributed/nccl_comm_task.h"
|
26 | 25 | #include "paddle/phi/core/distributed/nccl_tools.h"
|
@@ -146,7 +145,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
146 | 145 | place_to_group_key_(),
|
147 | 146 | pg_timeout_(timeout),
|
148 | 147 | nccl_comm_init_option_(nccl_comm_init_option),
|
149 |
| - allocation_stream_pairs_() { |
| 148 | + allocation_stream_pairs_(), |
| 149 | + place_to_p2p_opts_(), |
| 150 | + create_count_(0) { |
150 | 151 | LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_;
|
151 | 152 | LOG(INFO) << "ProcessGroupNCCL nccl_comm_init_option_ "
|
152 | 153 | << nccl_comm_init_option_;
|
@@ -948,12 +949,40 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
|
948 | 949 | platform::DeviceEvent(place, platform::GenerateDeviceEventFlag()));
|
949 | 950 | place_to_calc_ctx_.emplace(place_key, calc_ctx);
|
950 | 951 | place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));
|
| 952 | + place_to_p2p_opts_.emplace(place_key, std::move(p2p_opts)); |
951 | 953 |
|
952 | 954 | for (size_t i = 0; i < s_group_call_counter; ++i) {
|
953 | 955 | NCCL_CHECK(phi::dynload::ncclGroupStart());
|
954 | 956 | }
|
955 | 957 | }
|
956 | 958 |
|
| 959 | +void ProcessGroupNCCL::Shutdown() { |
| 960 | + for (size_t i = 0; i < s_group_call_counter; ++i) { |
| 961 | + NCCL_CHECK(phi::dynload::ncclGroupEnd()); |
| 962 | + } |
| 963 | + |
| 964 | + for (auto key_iter = place_to_group_key_.begin(); |
| 965 | + key_iter != place_to_group_key_.end(); |
| 966 | + ++key_iter) { |
| 967 | + std::string store_key = key_iter->second; |
| 968 | + auto nccl_comm_ctx = this->GetCommContext(&store_key); |
| 969 | + nccl_comm_ctx->DestroyNCCLComm(); |
| 970 | + } |
| 971 | +} |
| 972 | + |
| 973 | +void ProcessGroupNCCL::Restart() { |
| 974 | + for (auto key_iter = place_to_group_key_.begin(); |
| 975 | + key_iter != place_to_group_key_.end(); |
| 976 | + ++key_iter) { |
| 977 | + std::string place_key = key_iter->first; |
| 978 | + std::string store_key = key_iter->second; |
| 979 | + phi::distributed::P2POption p2p_opts = place_to_p2p_opts_.at(place_key); |
| 980 | + phi::distributed::CommContextManager::RecreateNCCLComm( |
| 981 | + store_, store_key, rank_, std::to_string(create_count_), &p2p_opts); |
| 982 | + create_count_++; |
| 983 | + } |
| 984 | +} |
| 985 | + |
957 | 986 | void ProcessGroupNCCL::SyncCalcStream(const Place& place,
|
958 | 987 | const std::string& place_key) {
|
959 | 988 | auto& calc_event = place_to_calc_event_.at(place_key);
|
|
0 commit comments