@@ -222,6 +222,7 @@ flagcxComm_t ProcessGroupFlagcx::FlagcxComm(const Place& place) const {
222
222
223
223
phi::distributed::FlagcxCommContext* ProcessGroupFlagcx::GetOrCreateCommContext (
224
224
const Place& place, CommType comm_type) {
225
+ VLOG (3 ) << " flagcx debug: entered ProcessGroupFlagcx::GetOrCreateCommContext" ;
225
226
const auto & key = GetKeyFromPlace (place);
226
227
std::string store_key;
227
228
GetStoreKey (key, comm_type, &store_key);
@@ -273,8 +274,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::AllReduce(
273
274
const AllreduceOptions& opts,
274
275
bool sync_op,
275
276
bool use_calc_stream) {
277
+ VLOG (3 ) << " flagcx debug: entered ProcessGroupFlagcx::AllReduce" <<
278
+ " sync_op: " << sync_op << " use_calc_stream: " << use_calc_stream;
276
279
CheckTensorContiguous (in_tensor);
277
280
CheckTensorContiguous (*out_tensor);
281
+ VLOG (3 ) << " flagcx debug: finished checking input and output tensor" ;
278
282
279
283
return Collective (
280
284
[&](phi::distributed::FlagcxCommContext* comm_context, flagcxStream_t stream) {
@@ -749,13 +753,15 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
749
753
const std::string& store_key,
750
754
CommType comm_type,
751
755
int p2p_rank) {
756
+ VLOG (3 ) << " flagcx debug: entered ProcessGroupFlagcx::CreateFlagcxEnvCache" ;
752
757
// TODO(changtao): we only support one flagcx comm ctx
753
758
if (flagcx_comm_ != nullptr ) {
754
759
return ;
755
760
}
756
761
VLOG (3 ) << " init flagcx rank_in_group: " << rank_ << " , nranks: " << size_
757
762
<< " , gid: " << gid_ << " , place key: " << place_key
758
763
<< " , store_key: " << store_key;
764
+ store_key_ = store_key;
759
765
760
766
// for (size_t i = 0; i < s_group_call_counter; ++i) {
761
767
// NCCL_CHECK(phi::dynload::ncclGroupEnd());
@@ -770,6 +776,7 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
770
776
// NCCL_CHECK(phi::dynload::ncclGroupStart());
771
777
772
778
// phi::distributed::P2POption p2p_opts({is_p2p_op, p2p_rank, num_ranks, rank});
779
+ VLOG (3 ) << " flagcx debug: before CreateFlagcxCommContext" ;
773
780
phi::distributed::CommContextManager::CreateFlagcxCommContext (
774
781
store_, store_key, rank_, size_, " " );
775
782
@@ -779,7 +786,7 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
779
786
VLOG (3 ) << " Get flagcx comm: " << flagcx_comm_ctx->GetFlagcxComm ();
780
787
// << " for place_key: " << place_key << " on rank_in_group: " << rank
781
788
// << " nranks: " << num_ranks << " gid: " << gid_;
782
-
789
+ VLOG ( 3 ) << " flagcx debug: get flagcx comm " ;
783
790
flagcx_comm_ = flagcx_comm_ctx->GetFlagcxComm ();
784
791
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
785
792
// comm_ctx->set_nccl_comm(flagcx_comm_ctx->GetFlagcxComm());
@@ -831,6 +838,7 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
831
838
auto * calc_ctx = static_cast <phi::GPUContext*>(
832
839
phi::DeviceContextPool::Instance ().Get (place));
833
840
841
+ VLOG (3 ) << " flagcx debug: adding key to maps" ;
834
842
place_to_calc_event_.emplace (
835
843
place_key,
836
844
platform::DeviceEvent (place, platform::GenerateDeviceEventFlag ()));
@@ -910,45 +918,59 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
910
918
CommType comm_type,
911
919
bool sync_op,
912
920
bool use_calc_stream) {
921
+ VLOG (3 ) << " flagcx debug: Entered ProcessGroupFlagcx::Collective" ;
913
922
CheckTensorContiguous (tensor);
923
+ VLOG (3 ) << " flagcx debug: finished checking tensor in Collective API" ;
914
924
915
925
comm_seq_++;
916
926
const auto & place = tensor.place ();
927
+ VLOG (3 ) << " flagcx debug: getting key from place" ;
917
928
const auto & key = GetKeyFromPlace (place);
918
929
930
+ VLOG (3 ) << " flagcx debug: adding cuda guard to device" ;
919
931
platform::CUDADeviceGuard cuda_guard (place);
920
932
921
933
std::string store_key;
934
+ VLOG (3 ) << " flagcx debug: getting store key" ;
922
935
GetStoreKey (key, comm_type, &store_key);
923
936
924
937
if (place_to_comm_ctx_.find (key) == place_to_comm_ctx_.end ()) {
938
+ VLOG (3 ) << " flagcx debug: creating flagcx env cache" ;
925
939
CreateFlagcxEnvCache (place, key, store_key, comm_type);
926
940
}
927
941
928
942
if (!use_calc_stream) {
943
+ VLOG (3 ) << " flagcx debug: syncing calc stream" ;
929
944
SyncCalcStream (place, key);
930
945
}
931
946
932
947
auto task =
933
948
CreateTask (place, rank_, comm_type, sync_op, use_calc_stream, gid_);
934
949
935
- const auto * calc_ctx = place_to_calc_ctx_. at (key) ;
950
+ VLOG ( 3 ) << " flagcx debug: getting comm context " ;
936
951
const auto & comm_ctx = place_to_comm_ctx_.at (key);
952
+ VLOG (3 ) << " flagcx debug: getting calc context" ;
953
+ const auto * calc_ctx = place_to_calc_ctx_.at (key);
937
954
// auto nccl_comm = comm_ctx->nccl_comm();
938
955
// auto flagcx_stream = use_calc_stream ? (flagcxStream_t)&calc_ctx->stream() : (flagcxStream_t)&comm_ctx->stream();
956
+
957
+ VLOG (3 ) << " flagcx debug: getting comm context" ;
958
+ auto flagcx_comm_ctx = this ->GetCommContext (&store_key);
959
+
939
960
flagcxStream_t flagcx_stream;
940
961
if (use_calc_stream) {
941
962
// Question: the returned stream type is essentially a cudaStream_t, can we cast it to flagcxStream_t?
963
+ VLOG (3 ) << " flagcx debug: getting calc stream" ;
942
964
auto calc_stream = calc_ctx->stream ();
943
- flagcx_stream = (flagcxStream_t )&calc_stream;
965
+ flagcx_comm_ctx-> flagcx_handler_ -> devHandle -> streamCopy (& flagcx_stream, ( void * )&calc_stream) ;
944
966
} else {
967
+ VLOG (3 ) << " flagcx debug: getting comm stream" ;
945
968
auto comm_stream = comm_ctx->stream ();
946
- flagcx_stream = (flagcxStream_t )&comm_stream;
969
+ flagcx_comm_ctx-> flagcx_handler_ -> devHandle -> streamCopy (& flagcx_stream, ( void * )&comm_stream) ;
947
970
}
948
971
949
- auto flagcx_comm_ctx = this ->GetCommContext (&store_key);
950
-
951
972
if (!FLAGS_enable_async_trace) {
973
+ VLOG (3 ) << " flagcx debug: calling function" ;
952
974
fn (flagcx_comm_ctx, flagcx_stream);
953
975
} else {
954
976
// std::string group_key = place_to_group_key_.at(key);
@@ -981,9 +1003,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
981
1003
// FLAGS_use_cuda_malloc_async_allocator) {
982
1004
// memory::RecordStream(tensor.Holder(), nccl_stream);
983
1005
// }
1006
+ VLOG (3 ) << " flagcx debug: not coalescing, updating wait chain" ;
984
1007
task->UpdateWaitChain (*comm_ctx);
985
- allocation_stream_pairs_.emplace_back (tensor.Holder (), flagcx_stream);
1008
+ allocation_stream_pairs_.emplace_back (tensor.Holder (), *(gpuStream_t*) flagcx_stream);
986
1009
} else {
1010
+ VLOG (3 ) << " flagcx debug: coalescing tensors" ;
987
1011
coalescing_tensors_.emplace_back (
988
1012
std::make_shared<phi::DenseTensor>(tensor));
989
1013
coalescing_place_keys_.push_back (key);
@@ -996,6 +1020,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
996
1020
// }
997
1021
998
1022
if (sync_op) {
1023
+ VLOG (3 ) << " flagcx debug: task wait" ;
999
1024
task->Wait ();
1000
1025
}
1001
1026
@@ -1006,6 +1031,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
1006
1031
// PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
1007
1032
// #endif
1008
1033
1034
+ VLOG (3 ) << " flagcx debug: free flagcx tmp stream" ;
1035
+ flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamFree (flagcx_stream);
1009
1036
1010
1037
return task;
1011
1038
}
@@ -1064,19 +1091,23 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
1064
1091
const auto * calc_ctx = place_to_calc_ctx_.at (key);
1065
1092
const auto & comm_ctx = place_to_comm_ctx_.at (key);
1066
1093
1094
+ auto flagcx_comm_ctx = this ->GetCommContext (&store_key);
1095
+
1067
1096
// auto nccl_comm = comm_ctx->nccl_comm();
1068
1097
// auto flagcx_stream = use_calc_stream ? (flagcxStream_t)&calc_ctx->stream() : (flagcxStream_t)&comm_ctx->stream();
1069
1098
flagcxStream_t flagcx_stream;
1070
1099
if (use_calc_stream) {
1071
1100
// Question: the returned stream type is essentially a cudaStream_t, can we cast it to flagcxStream_t?
1101
+ VLOG (3 ) << " flagcx debug: getting calc stream" ;
1072
1102
auto calc_stream = calc_ctx->stream ();
1073
- flagcx_stream = (flagcxStream_t )&calc_stream;
1103
+ flagcx_comm_ctx-> flagcx_handler_ -> devHandle -> streamCopy (& flagcx_stream, ( void * )&calc_stream) ;
1074
1104
} else {
1105
+ VLOG (3 ) << " flagcx debug: getting comm stream" ;
1075
1106
auto comm_stream = comm_ctx->stream ();
1076
- flagcx_stream = (flagcxStream_t )&comm_stream;
1107
+ flagcx_comm_ctx-> flagcx_handler_ -> devHandle -> streamCopy (& flagcx_stream, ( void * )&comm_stream) ;
1077
1108
}
1078
1109
1079
- std::string group_key = place_to_group_key_.at (key);
1110
+ // std::string group_key = place_to_group_key_.at(key);
1080
1111
// auto comm_task =
1081
1112
// std::make_shared<phi::distributed::FlagcxCommTask>(place,
1082
1113
// group_key,
@@ -1092,7 +1123,6 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
1092
1123
// comm_type,
1093
1124
// pg_timeout_);
1094
1125
1095
- auto flagcx_comm_ctx = this ->GetCommContext (&store_key);
1096
1126
1097
1127
if (!FLAGS_enable_async_trace) {
1098
1128
fn (flagcx_comm_ctx, flagcx_stream, p2p_target_rank);
@@ -1113,7 +1143,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
1113
1143
// memory::RecordStream(tensor.Holder(), nccl_stream);
1114
1144
// }
1115
1145
task->UpdateWaitChain (*comm_ctx);
1116
- allocation_stream_pairs_.emplace_back (tensor.Holder (), flagcx_stream);
1146
+ allocation_stream_pairs_.emplace_back (tensor.Holder (), *(gpuStream_t*) flagcx_stream);
1117
1147
} else {
1118
1148
coalescing_tensors_.emplace_back (
1119
1149
std::make_shared<phi::DenseTensor>(tensor));
@@ -1137,7 +1167,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
1137
1167
// PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
1138
1168
// #endif
1139
1169
1140
-
1170
+ flagcx_comm_ctx-> flagcx_handler_ -> devHandle -> streamFree (flagcx_stream);
1141
1171
return task;
1142
1172
}
1143
1173
@@ -1205,17 +1235,20 @@ void ProcessGroupFlagcx::EndCoalescing(
1205
1235
const auto & tensor = coalescing_tensors_[i];
1206
1236
const auto & key = coalescing_place_keys_[i];
1207
1237
const auto & comm_ctx = place_to_comm_ctx_.at (key);
1238
+ auto flagcx_comm_ctx = this ->GetCommContext (&store_key_);
1208
1239
// Question: the returned stream type is essentially a cudaStream_t, can we cast it to flagcxStream_t?
1209
1240
auto comm_stream = comm_ctx->stream ();
1210
- auto flagcx_stream = (flagcxStream_t)&comm_stream;
1241
+ flagcxStream_t flagcx_stream;
1242
+ flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamCopy (&flagcx_stream, (void *)&comm_stream);
1211
1243
1212
1244
// if (FLAGS_use_stream_safe_cuda_allocator ||
1213
1245
// FLAGS_use_cuda_malloc_async_allocator) {
1214
1246
// memory::RecordStream(tensor->Holder(), nccl_stream);
1215
1247
// }
1216
1248
1217
1249
flagcx_task->UpdateWaitChain (*comm_ctx);
1218
- allocation_stream_pairs_.emplace_back (tensor->Holder (), flagcx_stream);
1250
+ allocation_stream_pairs_.emplace_back (tensor->Holder (), *(gpuStream_t*)flagcx_stream);
1251
+ flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamFree (flagcx_stream);
1219
1252
}
1220
1253
1221
1254
is_coalescing_ = false ;
0 commit comments