
Commit c9b186e

fix bugs and test distributed api functions
1 parent 85baaf7 commit c9b186e

5 files changed: +63 / -24 lines


paddle/fluid/distributed/collective/process_group_flagcx.cc (+48 / -15)

@@ -222,6 +222,7 @@ flagcxComm_t ProcessGroupFlagcx::FlagcxComm(const Place& place) const {
 
 phi::distributed::FlagcxCommContext* ProcessGroupFlagcx::GetOrCreateCommContext(
     const Place& place, CommType comm_type) {
+  VLOG(3) << "flagcx debug: entered ProcessGroupFlagcx::GetOrCreateCommContext";
   const auto& key = GetKeyFromPlace(place);
   std::string store_key;
   GetStoreKey(key, comm_type, &store_key);
@@ -273,8 +274,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::AllReduce(
     const AllreduceOptions& opts,
     bool sync_op,
     bool use_calc_stream) {
+  VLOG(3) << "flagcx debug: entered ProcessGroupFlagcx::AllReduce" <<
+      " sync_op: " << sync_op << "use_calc_stream: " << use_calc_stream;
   CheckTensorContiguous(in_tensor);
   CheckTensorContiguous(*out_tensor);
+  VLOG(3) << "flagcx debug: finished checking input and output tensor";
 
   return Collective(
       [&](phi::distributed::FlagcxCommContext* comm_context, flagcxStream_t stream) {
@@ -749,13 +753,15 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
                                               const std::string& store_key,
                                               CommType comm_type,
                                               int p2p_rank) {
+  VLOG(3) << "flagcx debug: entered ProcessGroupFlagcx::CreateFlagcxEnvCache";
   //TODO(changtao): we only support one flagcx comm ctx
   if (flagcx_comm_ != nullptr) {
     return;
   }
   VLOG(3) << "init flagcx rank_in_group: " << rank_ << ", nranks: " << size_
           << ", gid: " << gid_ << ", place key: " << place_key
           << ", store_key: " << store_key;
+  store_key_ = store_key;
 
   // for (size_t i = 0; i < s_group_call_counter; ++i) {
   //   NCCL_CHECK(phi::dynload::ncclGroupEnd());
@@ -770,6 +776,7 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
   //   NCCL_CHECK(phi::dynload::ncclGroupStart());
 
   // phi::distributed::P2POption p2p_opts({is_p2p_op, p2p_rank, num_ranks, rank});
+  VLOG(3) << "flagcx debug: before CreateFlagcxCommContext";
   phi::distributed::CommContextManager::CreateFlagcxCommContext(
       store_, store_key, rank_, size_, "");
 
@@ -779,7 +786,7 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
   VLOG(3) << "Get flagcx comm: " << flagcx_comm_ctx->GetFlagcxComm();
   //         << " for place_key: " << place_key << " on rank_in_group: " << rank
   //         << " nranks: " << num_ranks << " gid: " << gid_;
-
+  VLOG(3) << "flagcx debug: get flagcx comm";
   flagcx_comm_ = flagcx_comm_ctx->GetFlagcxComm();
   auto comm_ctx = std::make_unique<phi::GPUContext>(place);
   // comm_ctx->set_nccl_comm(flagcx_comm_ctx->GetFlagcxComm());
@@ -831,6 +838,7 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
   auto* calc_ctx = static_cast<phi::GPUContext*>(
       phi::DeviceContextPool::Instance().Get(place));
 
+  VLOG(3) << "flagcx debug: adding key to maps";
   place_to_calc_event_.emplace(
       place_key,
       platform::DeviceEvent(place, platform::GenerateDeviceEventFlag()));
@@ -910,45 +918,59 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
     CommType comm_type,
     bool sync_op,
     bool use_calc_stream) {
+  VLOG(3) << "flagcx debug: Entered ProcessGroupFlagcx::Collective";
   CheckTensorContiguous(tensor);
+  VLOG(3) << "flagcx debug: finished checking tensor in Collective API";
 
   comm_seq_++;
   const auto& place = tensor.place();
+  VLOG(3) << "flagcx debug: getting key from place";
   const auto& key = GetKeyFromPlace(place);
 
+  VLOG(3) << "flagcx debug: adding cuda guard to device";
   platform::CUDADeviceGuard cuda_guard(place);
 
   std::string store_key;
+  VLOG(3) << "flagcx debug: getting store key";
   GetStoreKey(key, comm_type, &store_key);
 
   if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) {
+    VLOG(3) << "flagcx debug: creating flagcx env cache";
     CreateFlagcxEnvCache(place, key, store_key, comm_type);
   }
 
   if (!use_calc_stream) {
+    VLOG(3) << "flagcx debug: syncing calc stream";
     SyncCalcStream(place, key);
   }
 
   auto task =
       CreateTask(place, rank_, comm_type, sync_op, use_calc_stream, gid_);
 
-  const auto* calc_ctx = place_to_calc_ctx_.at(key);
+  VLOG(3) << "flagcx debug: getting comm context";
   const auto& comm_ctx = place_to_comm_ctx_.at(key);
+  VLOG(3) << "flagcx debug: getting calc context";
+  const auto* calc_ctx = place_to_calc_ctx_.at(key);
   // auto nccl_comm = comm_ctx->nccl_comm();
   // auto flagcx_stream = use_calc_stream ? (flagcxStream_t)&calc_ctx->stream() : (flagcxStream_t)&comm_ctx->stream();
+
+  VLOG(3) << "flagcx debug: getting comm context";
+  auto flagcx_comm_ctx = this->GetCommContext(&store_key);
+
   flagcxStream_t flagcx_stream;
   if (use_calc_stream) {
     // Question: the returned stream type is essentially a cudaStream_t, can we cast it to flagcxStream_t?
+    VLOG(3) << "flagcx debug: getting calc stream";
     auto calc_stream = calc_ctx->stream();
-    flagcx_stream = (flagcxStream_t)&calc_stream;
+    flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy(&flagcx_stream, (void *)&calc_stream);
   } else {
+    VLOG(3) << "flagcx debug: getting comm stream";
     auto comm_stream = comm_ctx->stream();
-    flagcx_stream = (flagcxStream_t)&comm_stream;
+    flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy(&flagcx_stream, (void *)&comm_stream);
   }
 
-  auto flagcx_comm_ctx = this->GetCommContext(&store_key);
-
   if (!FLAGS_enable_async_trace) {
+    VLOG(3) << "flagcx debug: calling function";
     fn(flagcx_comm_ctx, flagcx_stream);
   } else {
     // std::string group_key = place_to_group_key_.at(key);
@@ -981,9 +1003,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
     //     FLAGS_use_cuda_malloc_async_allocator) {
     //   memory::RecordStream(tensor.Holder(), nccl_stream);
     // }
+    VLOG(3) << "flagcx debug: not coalescing, updating wait chain";
     task->UpdateWaitChain(*comm_ctx);
-    allocation_stream_pairs_.emplace_back(tensor.Holder(), flagcx_stream);
+    allocation_stream_pairs_.emplace_back(tensor.Holder(), *(gpuStream_t*)flagcx_stream);
   } else {
+    VLOG(3) << "flagcx debug: coalescing tensors";
     coalescing_tensors_.emplace_back(
         std::make_shared<phi::DenseTensor>(tensor));
     coalescing_place_keys_.push_back(key);
@@ -996,6 +1020,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
   // }
 
   if (sync_op) {
+    VLOG(3) << "flagcx debug: task wait";
     task->Wait();
   }
 
@@ -1006,6 +1031,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
   //   PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
   // #endif
 
+  VLOG(3) << "flagcx debug: free flagcx tmp stream";
+  flagcx_comm_ctx->flagcx_handler_->devHandle->streamFree(flagcx_stream);
 
   return task;
 }
@@ -1064,19 +1091,23 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
   const auto* calc_ctx = place_to_calc_ctx_.at(key);
   const auto& comm_ctx = place_to_comm_ctx_.at(key);
 
+  auto flagcx_comm_ctx = this->GetCommContext(&store_key);
+
   // auto nccl_comm = comm_ctx->nccl_comm();
   // auto flagcx_stream = use_calc_stream ? (flagcxStream_t)&calc_ctx->stream() : (flagcxStream_t)&comm_ctx->stream();
   flagcxStream_t flagcx_stream;
   if (use_calc_stream) {
     // Question: the returned stream type is essentially a cudaStream_t, can we cast it to flagcxStream_t?
+    VLOG(3) << "flagcx debug: getting calc stream";
     auto calc_stream = calc_ctx->stream();
-    flagcx_stream = (flagcxStream_t)&calc_stream;
+    flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy(&flagcx_stream, (void *)&calc_stream);
   } else {
+    VLOG(3) << "flagcx debug: getting comm stream";
    auto comm_stream = comm_ctx->stream();
-    flagcx_stream = (flagcxStream_t)&comm_stream;
+    flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy(&flagcx_stream, (void *)&comm_stream);
   }
 
-  std::string group_key = place_to_group_key_.at(key);
+  // std::string group_key = place_to_group_key_.at(key);
   // auto comm_task =
   //     std::make_shared<phi::distributed::FlagcxCommTask>(place,
   //                                                        group_key,
@@ -1092,7 +1123,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
   //                                                        comm_type,
   //                                                        pg_timeout_);
 
-  auto flagcx_comm_ctx = this->GetCommContext(&store_key);
 
   if (!FLAGS_enable_async_trace) {
     fn(flagcx_comm_ctx, flagcx_stream, p2p_target_rank);
@@ -1113,7 +1143,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
     //   memory::RecordStream(tensor.Holder(), nccl_stream);
     // }
     task->UpdateWaitChain(*comm_ctx);
-    allocation_stream_pairs_.emplace_back(tensor.Holder(), flagcx_stream);
+    allocation_stream_pairs_.emplace_back(tensor.Holder(), *(gpuStream_t*)flagcx_stream);
   } else {
     coalescing_tensors_.emplace_back(
         std::make_shared<phi::DenseTensor>(tensor));
@@ -1137,7 +1167,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
   //   PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
   // #endif
 
-
+  flagcx_comm_ctx->flagcx_handler_->devHandle->streamFree(flagcx_stream);
   return task;
 }
 
@@ -1205,17 +1235,20 @@ void ProcessGroupFlagcx::EndCoalescing(
     const auto& tensor = coalescing_tensors_[i];
     const auto& key = coalescing_place_keys_[i];
     const auto& comm_ctx = place_to_comm_ctx_.at(key);
+    auto flagcx_comm_ctx = this->GetCommContext(&store_key_);
     // Question: the returned stream type is essentially a cudaStream_t, can we cast it to flagcxStream_t?
    auto comm_stream = comm_ctx->stream();
-    auto flagcx_stream = (flagcxStream_t)&comm_stream;
+    flagcxStream_t flagcx_stream;
+    flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy(&flagcx_stream, (void *)&comm_stream);
 
     // if (FLAGS_use_stream_safe_cuda_allocator ||
     //     FLAGS_use_cuda_malloc_async_allocator) {
     //   memory::RecordStream(tensor->Holder(), nccl_stream);
     // }
 
     flagcx_task->UpdateWaitChain(*comm_ctx);
-    allocation_stream_pairs_.emplace_back(tensor->Holder(), flagcx_stream);
+    allocation_stream_pairs_.emplace_back(tensor->Holder(), *(gpuStream_t*)flagcx_stream);
+    flagcx_comm_ctx->flagcx_handler_->devHandle->streamFree(flagcx_stream);
   }
 
   is_coalescing_ = false;
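
Note on the pattern above: rather than reinterpreting the address of a local cudaStream_t as a flagcxStream_t, the collective and point-to-point paths now obtain a temporary FlagCX stream handle via streamCopy, hand it to the communication callback, keep only the underlying gpuStream_t for allocation bookkeeping, and release the handle with streamFree before returning. A minimal sketch of that lifetime, assuming the Paddle/FlagCX types used in this file; the helper name RunOnFlagcxStream is illustrative, not part of the Paddle API:

// Sketch only: mirrors the lifetime used in ProcessGroupFlagcx::Collective,
// Point2Point and EndCoalescing after this commit. Relies on Paddle-internal
// headers for FlagcxCommContext, gpuStream_t and flagcxStream_t.
template <typename Fn>
void RunOnFlagcxStream(phi::distributed::FlagcxCommContext* comm_ctx,
                       gpuStream_t cuda_stream,
                       Fn&& fn) {
  // Wrap the existing CUDA stream in a temporary FlagCX stream handle.
  flagcxStream_t flagcx_stream;
  comm_ctx->flagcx_handler_->devHandle->streamCopy(&flagcx_stream,
                                                   (void*)&cuda_stream);

  // The communication callback runs on the wrapped stream.
  fn(comm_ctx, flagcx_stream);

  // Bookkeeping retains only the plain gpuStream_t (as in
  // allocation_stream_pairs_), so the temporary handle can be released
  // before the call returns.
  gpuStream_t underlying = *(gpuStream_t*)flagcx_stream;
  (void)underlying;
  comm_ctx->flagcx_handler_->devHandle->streamFree(flagcx_stream);
}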

paddle/fluid/distributed/collective/process_group_flagcx.h (+3 / -2)

@@ -270,9 +270,10 @@ class ProcessGroupFlagcx final : public ProcessGroupWithStream {
   int flagcx_comm_init_option_;
 
   // optimize memory for process_group
-  std::vector<std::pair<std::weak_ptr<phi::Allocation>, flagcxStream_t>>
+  std::vector<std::pair<std::weak_ptr<phi::Allocation>, gpuStream_t>>
       allocation_stream_pairs_;
-  flagcxComm_t flagcx_comm_;
+  flagcxComm_t flagcx_comm_{nullptr};
+  std::string store_key_;
 
   // For coalescing tensors processing (eg. batch_isend_irecv)
   bool is_coalescing_{false};
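
Note: the element type of allocation_stream_pairs_ changes because the temporary flagcxStream_t is freed before each call returns (see the .cc diff above), so only the underlying gpuStream_t can be kept. A hedged sketch of the relationship that the *(gpuStream_t*)flagcx_stream casts rely on; the helper name is illustrative, and the flagcxStream layout is an assumption about FlagCX, not something defined in this commit:

// Sketch only: the casts in process_group_flagcx.cc assume a flagcxStream_t
// points at storage whose leading bytes are the underlying CUDA stream, so
// dereferencing recovers a handle that stays valid after streamFree()
// releases the wrapper.
inline gpuStream_t UnderlyingGpuStream(flagcxStream_t s) {
  return *(gpuStream_t*)s;  // same cast as in the diff above
}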

paddle/phi/core/distributed/comm_context_manager.cc (+6 / -4)

@@ -253,11 +253,10 @@ void CommContextManager::CreateFlagcxCommContext(const std::shared_ptr<Store>& s
     return;
   }
   flagcxHandlerGroup_t flagcx_handler;
-  // phi::dynload::flagcxHandleInit(&flagcx_handler);
-  flagcxHandleInit(&flagcx_handler);
+  VLOG(3) << "flagcx debug: flagcxHendleInit";
+  phi::dynload::flagcxHandleInit(&flagcx_handler);
   if (rank == 0) {
-    // phi::dynload::flagcxGetUniqueId(&flagcx_handler->uniqueId);
-    flagcxGetUniqueId(&flagcx_handler->uniqueId);
+    phi::dynload::flagcxGetUniqueId(&flagcx_handler->uniqueId);
   }
 
   std::string unique_key = "FlagcxCommContext/" + unique_comm_key + hash_key;
@@ -299,8 +298,11 @@ void CommContextManager::CreateFlagcxCommContext(const std::shared_ptr<Store>& s
   //   flagcx_comm_context->SetComputeEvent(std::move(compute_event));
   //   flagcx_comm_context->SetCommEvent(std::move(comm_event));
   // }
+  comm_context_manager.SetStore(store);
+  comm_context_manager.Emplace(unique_comm_key, std::move(flagcx_comm_context));
 }
 #endif
+
 CommContext* CommContextManager::Emplace(
     const std::string& unique_comm_key,
     std::unique_ptr<CommContext> comm_context) {
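
Note: the added SetStore and Emplace calls register the newly built context with the manager, which is what later lookups by store key (for example from ProcessGroupFlagcx::GetCommContext) depend on. A hedged sketch of the lookup side, assuming CommContextManager's usual GetInstance/Get interface; the helper name LookupFlagcxCommContext is illustrative:

// Sketch only: retrieving the FlagcxCommContext registered by
// CreateFlagcxCommContext under unique_comm_key.
phi::distributed::FlagcxCommContext* LookupFlagcxCommContext(
    const std::string& unique_comm_key) {
  auto& manager = phi::distributed::CommContextManager::GetInstance();
  // Get() returns the base CommContext*; the flagcx backend emplaced a
  // FlagcxCommContext above, so the downcast matches what was stored.
  return static_cast<phi::distributed::FlagcxCommContext*>(
      manager.Get(unique_comm_key));
}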

paddle/phi/core/distributed/flagcx_comm_context.h (+4 / -2)

@@ -80,15 +80,17 @@ class FlagcxCommContext final : public CommContext {
 
   int flagcx_version_;
 
-  flagcxHandlerGroup_t flagcx_handler_;
-
   std::unique_ptr<phi::GPUContext> dev_ctx_;
 
   // used for comm wait compute, compute_stream-->event-->comm_stream
   std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> compute_event_;
 
   // used for compute wait comm, comm_stream-->event-->compute_stream
   std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> comm_event_;
+
+ public:
+  flagcxHandlerGroup_t flagcx_handler_;
+
 };
 
 }  // namespace distributed

python/paddle/distributed/collective.py (+2 / -1)

@@ -179,7 +179,8 @@ def _new_process_group_impl(
     elif backend == "bkcl":
         pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id)
     elif backend == "flagcx":
-        pg = core.ProcessGroupFlagcx.create(store, rank, world_size, group_id)
+        pg = core.ProcessGroupFlagcx.create(store, rank, world_size, group_id, genv.pg_timeout,
+                                            nccl_comm_init_option)
     return pg
 
 
