@@ -187,13 +187,12 @@ phi::DeviceContext* ProcessGroupFlagcx::GetDeviceContext(
187
187
}
188
188
189
189
flagcxComm_t ProcessGroupFlagcx::FlagcxComm (const Place& place) const {
190
- PADDLE_ENFORCE_NOT_NULL (flagcx_comm_);
190
+ PADDLE_ENFORCE_NOT_NULL (flagcx_comm_, :: common::errors::InvalidArgument ( " flagcx_comm_ is nullptr " ) );
191
191
return flagcx_comm_;
192
192
}
193
193
194
194
phi::distributed::FlagcxCommContext* ProcessGroupFlagcx::GetOrCreateCommContext (
195
195
const Place& place, CommType comm_type) {
196
- VLOG (3 ) << " flagcx debug: entered ProcessGroupFlagcx::GetOrCreateCommContext" ;
197
196
const auto & key = GetKeyFromPlace (place);
198
197
std::string store_key;
199
198
GetStoreKey (key, comm_type, &store_key);
@@ -245,11 +244,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::AllReduce(
245
244
const AllreduceOptions& opts,
246
245
bool sync_op,
247
246
bool use_calc_stream) {
248
- VLOG (3 ) << " flagcx debug: entered ProcessGroupFlagcx::AllReduce" <<
249
- " sync_op: " << sync_op << " use_calc_stream: " << use_calc_stream;
250
247
CheckTensorContiguous (in_tensor);
251
248
CheckTensorContiguous (*out_tensor);
252
- VLOG (3 ) << " flagcx debug: finished checking input and output tensor" ;
253
249
254
250
return Collective (
255
251
[&](phi::distributed::FlagcxCommContext* comm_context, flagcxStream_t stream) {
@@ -693,7 +689,6 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
693
689
const std::string& store_key,
694
690
CommType comm_type,
695
691
int p2p_rank) {
696
- VLOG (3 ) << " flagcx debug: entered ProcessGroupFlagcx::CreateFlagcxEnvCache" ;
697
692
// TODO(changtao): we only support one flagcx comm ctx
698
693
if (flagcx_comm_ != nullptr ) {
699
694
return ;
@@ -703,22 +698,19 @@ void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place,
703
698
<< " , store_key: " << store_key;
704
699
store_key_ = store_key;
705
700
706
- VLOG (3 ) << " flagcx debug: before CreateFlagcxCommContext" ;
707
701
phi::distributed::CommContextManager::CreateFlagcxCommContext (
708
702
store_, store_key, rank_, size_, " " );
709
703
710
704
711
705
auto flagcx_comm_ctx = this ->GetCommContext (&store_key);
712
706
VLOG (3 ) << " Get flagcx comm: " << flagcx_comm_ctx->GetFlagcxComm ();
713
- VLOG (3 ) << " flagcx debug: get flagcx comm" ;
714
707
flagcx_comm_ = flagcx_comm_ctx->GetFlagcxComm ();
715
708
auto comm_ctx = std::make_unique<phi::GPUContext>(place);
716
709
717
710
718
711
auto * calc_ctx = static_cast <phi::GPUContext*>(
719
712
phi::DeviceContextPool::Instance ().Get (place));
720
713
721
- VLOG (3 ) << " flagcx debug: adding key to maps" ;
722
714
place_to_calc_event_.emplace (
723
715
place_key,
724
716
platform::DeviceEvent (place, platform::GenerateDeviceEventFlag ()));
@@ -795,78 +787,61 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Collective(
795
787
CommType comm_type,
796
788
bool sync_op,
797
789
bool use_calc_stream) {
798
- VLOG (3 ) << " flagcx debug: Entered ProcessGroupFlagcx::Collective" ;
799
790
CheckTensorContiguous (tensor);
800
- VLOG (3 ) << " flagcx debug: finished checking tensor in Collective API" ;
801
791
802
792
comm_seq_++;
803
793
const auto & place = tensor.place ();
804
- VLOG (3 ) << " flagcx debug: getting key from place" ;
805
794
const auto & key = GetKeyFromPlace (place);
806
795
807
- VLOG (3 ) << " flagcx debug: adding cuda guard to device" ;
808
796
platform::CUDADeviceGuard cuda_guard (place);
809
797
810
798
std::string store_key;
811
- VLOG (3 ) << " flagcx debug: getting store key" ;
812
799
GetStoreKey (key, comm_type, &store_key);
813
800
814
801
if (place_to_comm_ctx_.find (key) == place_to_comm_ctx_.end ()) {
815
- VLOG (3 ) << " flagcx debug: creating flagcx env cache" ;
816
802
CreateFlagcxEnvCache (place, key, store_key, comm_type);
817
803
}
818
804
819
805
if (!use_calc_stream) {
820
- VLOG (3 ) << " flagcx debug: syncing calc stream" ;
821
806
SyncCalcStream (place, key);
822
807
}
823
808
824
809
auto task =
825
810
CreateTask (place, rank_, comm_type, sync_op, use_calc_stream, gid_);
826
811
827
- VLOG (3 ) << " flagcx debug: getting comm context" ;
828
812
const auto & comm_ctx = place_to_comm_ctx_.at (key);
829
- VLOG (3 ) << " flagcx debug: getting calc context" ;
830
813
const auto * calc_ctx = place_to_calc_ctx_.at (key);
831
814
832
- VLOG (3 ) << " flagcx debug: getting comm context" ;
833
815
auto flagcx_comm_ctx = this ->GetCommContext (&store_key);
834
816
835
817
flagcxStream_t flagcx_stream;
836
818
if (use_calc_stream) {
837
- VLOG (3 ) << " flagcx debug: getting calc stream" ;
838
819
auto calc_stream = calc_ctx->stream ();
839
820
flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamCopy (&flagcx_stream, (void *)&calc_stream);
840
821
} else {
841
- VLOG (3 ) << " flagcx debug: getting comm stream" ;
842
822
auto comm_stream = comm_ctx->stream ();
843
823
flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamCopy (&flagcx_stream, (void *)&comm_stream);
844
824
}
845
825
846
826
if (!FLAGS_enable_async_trace) {
847
- VLOG (3 ) << " flagcx debug: calling function" ;
848
827
fn (flagcx_comm_ctx, flagcx_stream);
849
828
}
850
829
851
830
if (!use_calc_stream) {
852
831
if (!is_coalescing_) {
853
- VLOG (3 ) << " flagcx debug: not coalescing, updating wait chain" ;
854
832
task->UpdateWaitChain (*comm_ctx);
855
833
allocation_stream_pairs_.emplace_back (tensor.Holder (), *(gpuStream_t*)flagcx_stream);
856
834
} else {
857
- VLOG (3 ) << " flagcx debug: coalescing tensors" ;
858
835
coalescing_tensors_.emplace_back (
859
836
std::make_shared<phi::DenseTensor>(tensor));
860
837
coalescing_place_keys_.push_back (key);
861
838
}
862
839
}
863
840
864
841
if (sync_op) {
865
- VLOG (3 ) << " flagcx debug: task wait" ;
866
842
task->Wait ();
867
843
}
868
844
869
- VLOG (3 ) << " flagcx debug: free flagcx tmp stream" ;
870
845
flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamFree (flagcx_stream);
871
846
872
847
return task;
@@ -927,11 +902,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupFlagcx::Point2Point(
927
902
928
903
flagcxStream_t flagcx_stream;
929
904
if (use_calc_stream) {
930
- VLOG (3 ) << " flagcx debug: getting calc stream" ;
931
905
auto calc_stream = calc_ctx->stream ();
932
906
flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamCopy (&flagcx_stream, (void *)&calc_stream);
933
907
} else {
934
- VLOG (3 ) << " flagcx debug: getting comm stream" ;
935
908
auto comm_stream = comm_ctx->stream ();
936
909
flagcx_comm_ctx->flagcx_handler_ ->devHandle ->streamCopy (&flagcx_stream, (void *)&comm_stream);
937
910
}
0 commit comments