diff --git a/.gitmodules b/.gitmodules
index 80ecb749c4db6..7d343bd148514 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -129,3 +129,7 @@
 path = third_party/openvino
 url = https://github.com/openvinotoolkit/openvino.git
 ignore = dirty
+[submodule "third_party/flagcx"]
+	path = third_party/flagcx
+	url = https://github.com/FlagOpen/FlagCX.git
+	ignore = dirty
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a36c8183457a0..c799e33558d99 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -313,6 +313,7 @@ option(
     OFF)
 option(WITH_CINN "Compile PaddlePaddle with CINN" OFF)
 option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
+option(WITH_FLAGCX "Compile PaddlePaddle with FLAGCX support" OFF)
 option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON)
 option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF)
 option(WITH_CRYPTO "Compile PaddlePaddle with crypto support" ON)
@@ -538,6 +539,11 @@ else()
   endif()
 endif()
 
+if(WITH_FLAGCX)
+  add_definitions("-DPADDLE_WITH_FLAGCX")
+  # include(flagcx)
+endif()
+
 if(WITH_HETERPS AND WITH_PSLIB)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
 endif()
diff --git a/cmake/external/flagcx.cmake b/cmake/external/flagcx.cmake
new file mode 100644
index 0000000000000..22f008d13fef6
--- /dev/null
+++ b/cmake/external/flagcx.cmake
@@ -0,0 +1,47 @@
+set(CMAKE_FIND_DEBUG_MODE ON)
+# flagcx.cmake
+if(NOT WITH_FLAGCX)
+  return()
+endif()
+
+set(FLAGCX_SOURCE_DIR "${PADDLE_SOURCE_DIR}/third_party/flagcx")
+set(FLAGCX_BINARY_DIR "${PADDLE_SOURCE_DIR}/build/third_party/flagcx")
+set(THIRD_PARTY_DIR "${PADDLE_SOURCE_DIR}/build/third_party")
+set(FLAGCX_ROOT "/usr/local/flagcx")
+set(FLAGCX_LIB_DIR "${FLAGCX_BINARY_DIR}/build/lib")
+set(USR_LOCAL_DIR "/usr/local")
+
+file(REMOVE_RECURSE ${FLAGCX_BINARY_DIR})
+message(STATUS "removed old flagcx dir")
+message(STATUS "Copying third-party source to build directory")
+execute_process(COMMAND cp -r ${FLAGCX_SOURCE_DIR} ${THIRD_PARTY_DIR}
+                RESULT_VARIABLE COPY_RESULT)
+
+if(NOT COPY_RESULT EQUAL 0)
+  message(FATAL_ERROR "Failed to copy third-party source to build directory")
+endif()
+
+# Create a custom target to build the third-party library
+message(STATUS "Building third-party library with its Makefile")
+execute_process(
+  COMMAND make
+  WORKING_DIRECTORY ${FLAGCX_BINARY_DIR}
+  RESULT_VARIABLE BUILD_RESULT)
+
+find_path(
+  FLAGCX_INCLUDE_DIR flagcx.h
+  PATHS ${FLAGCX_SOURCE_DIR}/flagcx/include
+  NO_DEFAULT_PATH)
+
+message(STATUS "FLAGCX_INCLUDE_DIR is ${FLAGCX_INCLUDE_DIR}")
+include_directories(SYSTEM ${FLAGCX_INCLUDE_DIR})
+
+add_library(flagcx INTERFACE)
+find_library(
+  FLAGCX_LIB
+  NAMES flagcx libflagcx
+  PATHS ${FLAGCX_LIB_DIR}
+  DOC "My custom library")
+
+add_dependencies(flagcx FLAGCX_LIB)
+message(STATUS "FLAGCX_LIB is ${FLAGCX_LIB}")
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index e965c091b5d51..441fc489feb58 100755
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -469,6 +469,10 @@ if(WITH_TESTING OR WITH_DISTRIBUTE)
   include(external/gtest) # download, build, install gtest
   list(APPEND third_party_deps extern_gtest)
 endif()
+if(WITH_FLAGCX)
+  include(external/flagcx)
+  list(APPEND third_party_deps flagcx)
+endif()
 
 if(WITH_ONNXRUNTIME)
   include(external/onnxruntime
diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc
index b486008e206ca..294e68e381d7b 100644
--- a/paddle/common/flags.cc
+++ b/paddle/common/flags.cc
@@ -1265,6 +1265,19 @@ PHI_DEFINE_EXPORTED_bool(multi_node_sample_use_gpu_table,
 PHI_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
 #endif
 
+/**
+ * ProcessGroupFlagCX related FLAG
+ * Name: flagcx_blocking_wait
+ * Since Version:
+ * Value Range: bool, default=false
+ * Example:
+ * Note: flagcx blocking wait.
+ * blocks host thread until collective operation completes
+ */
+#if defined(PADDLE_WITH_FLAGCX)
+PHI_DEFINE_EXPORTED_bool(flagcx_blocking_wait, false, "flagcx blocking wait");
+#endif
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 PHI_DEFINE_EXPORTED_bool(benchmark_nccl,
                          false,
@@ -1770,6 +1783,13 @@ PHI_DEFINE_EXPORTED_string(
     "For instance, /usr/local/cuda/lib64. If default, "
     "dlopen will search cuda from LD_LIBRARY_PATH");
 
+PHI_DEFINE_EXPORTED_string(
+    flagcx_dir,  // NOLINT
+    "/usr/local/flagcx/build/lib",
+    "Specify path for loading libflagcx.so. For instance, "
+    "/usr/local/flagcx/lib. If default, "
+    "dlopen will search flagcx from LD_LIBRARY_PATH");
+
 PHI_DEFINE_EXPORTED_string(cupti_dir,
                            "",
                            "Specify path for loading cupti.so.");  // NOLINT
diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt
index 1febfa1c40d79..a9763dc44b79a 100644
--- a/paddle/fluid/distributed/collective/CMakeLists.txt
+++ b/paddle/fluid/distributed/collective/CMakeLists.txt
@@ -36,6 +36,13 @@ if(WITH_NCCL OR WITH_RCCL)
 endif()
 
+if(WITH_FLAGCX)
+  cc_library(
+    process_group_flagcx
+    SRCS process_group_flagcx.cc common.cc
+    DEPS process_group phi)
+endif()
+
 if(WITH_XPU_BKCL)
   cc_library(
     process_group_bkcl
diff --git a/paddle/fluid/distributed/collective/process_group_flagcx.cc b/paddle/fluid/distributed/collective/process_group_flagcx.cc
new file mode 100644
index 0000000000000..e8cd11bb4ab68
--- /dev/null
+++ b/paddle/fluid/distributed/collective/process_group_flagcx.cc
@@ -0,0 +1,1126 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/distributed/collective/process_group_flagcx.h" +#include "paddle/common/flags.h" +#include "paddle/fluid/distributed/collective/common.h" +#include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" +#include "paddle/phi/core/distributed/comm_task_manager.h" +#include "paddle/phi/core/distributed/flagcx_tools.h" +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/platform/cuda_device_guard.h" +#include "paddle/phi/core/platform/device/gpu/nccl_helper.h" +#include "paddle/phi/core/utils/data_type.h" + +COMMON_DECLARE_bool(flagcx_blocking_wait); +COMMON_DECLARE_bool(enable_async_trace); +COMMON_DECLARE_bool(eager_communication_connection); + +// set this flag to `true` and recompile to enable dynamic checks +// constexpr bool FLAGS_enable_nccl_dynamic_check = false; +constexpr int64_t kWaitBlockTImeout = 10; + +namespace paddle::distributed { + +using phi::distributed::CheckSizeOnEachRank; +using phi::distributed::FlagcxDTypeToString; +using phi::distributed::FlagcxRedTypeToString; +using phi::distributed::IsP2POP; +using phi::distributed::SerializeFlagcxUniqueId; +using phi::distributed::ToFlagcxRedType; + +uint64_t ProcessGroupFlagcx::s_group_call_counter = 0; + +ProcessGroupFlagcx::FlagcxTask::FlagcxTask(const Place& place, + int rank, + CommType comm_type, + bool sync_op, + bool use_calc_stream, + int gid) + : TaskStream(rank, comm_type, sync_op, use_calc_stream), + task_place_(place), + gid_(gid) { + if (!use_calc_stream) { + comm_event_ = std::make_shared( + place, platform::GenerateDeviceEventFlag()); + } +} + +ProcessGroupFlagcx::FlagcxTask::~FlagcxTask() = default; + +bool ProcessGroupFlagcx::FlagcxTask::IsCompleted() { + if (comm_event_) { + return comm_event_->Query(); + } else { + return true; + } +} + +void ProcessGroupFlagcx::FlagcxTask::UpdateWaitChain( + const phi::DeviceContext& ctx) { + if (comm_event_) { + comm_event_->Record(&ctx); + } +} + +void ProcessGroupFlagcx::FlagcxTask::RemoveHolderStreamInGroup() { + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + distributed::ProcessGroup* pg = map->get(gid_); + if (!pg) return; + auto* pg_flagcx = dynamic_cast(pg); + if (!pg_flagcx) return; + pg_flagcx->EraseTensorHolders(); +} + +// TODO(sheniang03): Add timeout for wait, now timeout unused +bool ProcessGroupFlagcx::FlagcxTask::Wait(std::chrono::milliseconds timeout) { + // Warning here when use calc stream but also invoke waiting explicitly. 
+ if (UseCalcStream()) { + VLOG(5) << "Warning: The communication is on calc stream, wait here is " + "useless."; + return true; + } + + const auto* calc_ctx = + platform::DeviceContextPool::Instance().Get(task_place_); + if (comm_event_) { + comm_event_->Wait(platform::Place2DeviceType(task_place_), calc_ctx); + } + + if (FLAGS_flagcx_blocking_wait) { + // NOTE(shenliang03): It will block host for sync + while (!IsCompleted()) { + std::this_thread::sleep_for(std::chrono::milliseconds(kWaitBlockTImeout)); + } + } + RemoveHolderStreamInGroup(); + return true; +} + +// Same as Wait +void ProcessGroupFlagcx::FlagcxTask::Synchronize() { Wait(kWaitTimeout); } + +ProcessGroupFlagcx::ProcessGroupFlagcx( + const std::shared_ptr& store, + int rank, + int size, + int gid, + int64_t timeout, + int flagcx_comm_init_option) + : ProcessGroupWithStream(rank, size, gid), + store_(store), + place_to_calc_event_(), + place_to_calc_ctx_(), + place_to_comm_ctx_(), + p2p_comm_seq_(), + place_to_group_key_(), + pg_timeout_(timeout), + flagcx_comm_init_option_(flagcx_comm_init_option), + allocation_stream_pairs_() { + LOG(INFO) << "ProcessGroupFlagcx pg_timeout_ " << pg_timeout_; + LOG(INFO) << "ProcessGroupFlagcx flagcx_comm_init_option_ " + << flagcx_comm_init_option_; + if (FLAGS_eager_communication_connection) { + EagerConnect(); + } +} +ProcessGroupFlagcx::~ProcessGroupFlagcx() { + LOG(INFO) << "ProcessGroupFlagcx destruct "; +} + +void ProcessGroupFlagcx::GroupStart() { + if (flagcx_comm_ != nullptr) { + FLAGCX_CHECK(phi::dynload::flagcxGroupStart(flagcx_comm_)); + ++s_group_call_counter; + } +} + +void ProcessGroupFlagcx::GroupEnd() { + if (flagcx_comm_ != nullptr) { + FLAGCX_CHECK(phi::dynload::flagcxGroupEnd(flagcx_comm_)); + --s_group_call_counter; + } +} + +phi::DeviceContext* ProcessGroupFlagcx::GetDeviceContext( + const Place& place) const { + return GetDeviceContext(place, /*use_calc_stream*/ false); +} + +// NOTE(shenliang03): GetDeviceContext is only used for collective, it can't +// be used for p2p op. +phi::DeviceContext* ProcessGroupFlagcx::GetDeviceContext( + const Place& place, bool use_calc_stream) const { + const std::string& key = GetKeyFromPlace(place); + if (use_calc_stream) { + const auto& iter = place_to_calc_ctx_.find(key); + return iter->second; + } else { + const auto& iter = place_to_comm_ctx_.find(key); + PADDLE_ENFORCE_NE( + iter, + place_to_comm_ctx_.end(), + common::errors::NotFound( + "Cannot find the device context in this process group.")); + return iter->second.get(); + } +} + +flagcxComm_t ProcessGroupFlagcx::FlagcxComm(const Place& place) const { + PADDLE_ENFORCE_NOT_NULL( + flagcx_comm_, + ::common::errors::InvalidArgument("flagcx_comm_ is nullptr")); + return flagcx_comm_; +} + +phi::distributed::FlagcxCommContext* ProcessGroupFlagcx::GetOrCreateCommContext( + const Place& place, CommType comm_type) { + const auto& key = GetKeyFromPlace(place); + std::string store_key; + GetStoreKey(key, comm_type, &store_key); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateFlagcxEnvCache(place, key, store_key, comm_type); + } + return GetCommContext(&store_key); +} + +std::shared_ptr ProcessGroupFlagcx::AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + // numel > 0 indicates the tensor need to be sliced + const phi::DenseTensor& in_tensor_maybe_partial = + numel > 0 ? 
GetPartialTensor(in_tensor, offset, numel) : in_tensor; + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + VLOG(3) << "[flagcxAllGather] " + << "sendbuff: " << in_tensor_maybe_partial.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor_maybe_partial.numel() + << ", datatype: " + << FlagcxDTypeToString( + phi::ToFlagcxDataType(in_tensor_maybe_partial.dtype())) + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", offset: " << offset + << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + comm_context->AllGather(out_tensor, in_tensor_maybe_partial, stream); + }, + in_tensor_maybe_partial, + CommType::ALLGATHER, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + VLOG(3) << "[flagcxAllReduce] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(in_tensor.dtype())) + << ", redop: " + << FlagcxRedTypeToString(ToFlagcxRedType(opts.reduce_op)) + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + + comm_context->AllReduce( + out_tensor, in_tensor, ToFlagcxRedType(opts.reduce_op), stream); + }, + in_tensor, + CommType::ALLREDUCE, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + std::vector out_split_sizes; + std::vector in_split_sizes; + if (out_size_each_rank.empty() && in_size_each_rank.empty()) { + out_split_sizes = + std::vector(size_, out_tensor->dims()[0] / size_); + in_split_sizes = std::vector(size_, in_tensor.dims()[0] / size_); + } else { + out_split_sizes = out_size_each_rank; + in_split_sizes = in_size_each_rank; + } + + const phi::DDim& out_dim = out_tensor->dims(); + const phi::DDim& in_dim = in_tensor.dims(); + // CheckSizeOnEachRank(out_dim, out_size_each_rank, size_); + // CheckSizeOnEachRank(in_dim, in_size_each_rank, size_); + CheckSizeOnEachRank(out_dim, out_split_sizes, size_); + CheckSizeOnEachRank(in_dim, in_split_sizes, size_); + + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + int64_t in_row_size = + in_dim[0] == 0 ? 0 : in_tensor.numel() / in_dim[0]; + int64_t out_row_size = + out_dim[0] == 0 ? 
0 : out_tensor->numel() / out_dim[0]; + int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0; + phi::DenseTensor input_partial, output_partial; + + VLOG(3) << "[AllToAll] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(in_tensor.dtype())) + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", out_split_sizes: " + << string::join_strings(out_split_sizes, ',') + << ", in_split_sizes: " + << string::join_strings(in_split_sizes, ',') + << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + + GroupStart(); + for (auto i = 0; i < size_; i++) { + in_numel = in_split_sizes[i] * in_row_size; + + if (in_numel > 0) { + input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); + comm_context->Send(input_partial, in_numel, i, stream); + } + in_offset += in_numel; + out_numel = out_split_sizes[i] * out_row_size; + if (out_numel > 0) { + output_partial = + GetPartialTensor(*out_tensor, out_offset, out_numel); + comm_context->Recv(&output_partial, out_numel, i, stream); + } + out_offset += out_numel; + } + GroupEnd(); + }, + in_tensor, + CommType::ALLTOALL, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::AllToAll( + std::vector* out_tensors, + const std::vector& in_tensors, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensors); + CheckTensorContiguous(*out_tensors); + CheckTensorSamePlace(in_tensors); + CheckTensorSamePlace(*out_tensors); + phi::distributed::CommStaticCheck::CheckDataType(*out_tensors, in_tensors); + + PADDLE_ENFORCE_EQ( + out_tensors->size(), + size_, + common::errors::InvalidArgument( + "Number of out tensors[%d] do not match the world size[%d].", + out_tensors->size(), + size_)); + PADDLE_ENFORCE_EQ( + in_tensors.size(), + size_, + common::errors::InvalidArgument( + "Number of in tensors[%d] do not match the world size[%d].", + in_tensors.size(), + size_)); + + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + VLOG(3) << "[AllToAll] " + << "sendbuff: " + << string::join_strings(GetTensorPtrs(in_tensors), ',') + << ", recvbuff: " + << string::join_strings(GetTensorPtrs(*out_tensors), ',') + << ", datatype: " + << FlagcxDTypeToString( + phi::ToFlagcxDataType(in_tensors[0].dtype())) + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", out_split_sizes: " + << string::join_strings(GetAllToAllSplitSizes(*out_tensors), + ',') + << ", in_split_sizes: " + << string::join_strings(GetAllToAllSplitSizes(in_tensors), ',') + << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + + GroupStart(); + for (auto i = 0; i < size_; i++) { + int64_t in_numel = in_tensors[i].numel(); + int64_t out_numel = (*out_tensors)[i].numel(); + + if (in_numel > 0) { + comm_context->Send(in_tensors[i], in_numel, i, stream); + } + + if (out_numel > 0) { + comm_context->Recv(&(*out_tensors)[i], out_numel, i, stream); + } + } + GroupEnd(); + }, + in_tensors, + CommType::ALLTOALL, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Barrier( + const BarrierOptions& opts) { + PADDLE_ENFORCE_GE(opts.device_id, + 0, + common::errors::PreconditionNotMet( + "The 
barrier device id must greater or equal than 0.")); + phi::GPUPlace place(opts.device_id); + auto allocator = std::unique_ptr( + new paddle::experimental::DefaultAllocator(place)); + phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1}); + phi::DenseTensor barrier_tensor{allocator.get(), meta}; + + VLOG(3) << "[Barrier] " + << "barrier opt: " << opts.device_id; + + auto task = AllReduce(&barrier_tensor, + barrier_tensor, + {}, + /*sync_op*/ true, + /*use_calc_stream*/ false); + auto flagcx_task = dynamic_cast(task.get()); + flagcx_task->SetBlockCPUInWait(); + return task; +} + +std::shared_ptr ProcessGroupFlagcx::Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + int root = opts.source_rank + opts.source_root; + + VLOG(3) << "[flagcxBroadcast] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(in_tensor.dtype())) + << ", root: " << root + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + comm_context->Broadcast(out_tensor, in_tensor, root, stream); + }, + in_tensor, + CommType::BROADCAST, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Reduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + VLOG(3) << "[flagcxReduce] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(in_tensor.dtype())) + << ", redop: " + << FlagcxRedTypeToString(ToFlagcxRedType(opts.reduce_op)) + << ", root: " << opts.root_rank + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + comm_context->Reduce(out_tensor, + in_tensor, + ToFlagcxRedType(opts.reduce_op), + opts.root_rank, + stream); + }, + in_tensor, + CommType::REDUCE, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + VLOG(3) << "[flagcxReduceScatter] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(in_tensor.dtype())) + << ", redop: " + << FlagcxRedTypeToString(ToFlagcxRedType(opts.reduce_op)) + << ", flagcxcomm: " << 
comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + comm_context->ReduceScatter( + out_tensor, in_tensor, ToFlagcxRedType(opts.reduce_op), stream); + }, + in_tensor, + CommType::REDUCE_SCATTER, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Scatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + phi::distributed::CommStaticCheck::ScatterLikeShape( + *out_tensor, + in_tensor, + /*dst_rank*/ opts.root_rank, + /*cur_rank*/ rank_, + size_); + return Collective( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + VLOG(3) << "[Scatter] " + << "sendbuff: " << in_tensor.data() + << ", recvbuff: " << out_tensor->data() + << ", count: " << in_tensor.numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(in_tensor.dtype())) + << ", root: " << opts.root_rank + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + + int64_t numel = in_tensor.numel() / size_; + if (rank_ == opts.root_rank) { + int64_t offset = 0; + phi::DenseTensor partial_tensor; + this->GroupStart(); + for (auto i = 0; i < size_; i++) { + partial_tensor = GetPartialTensor(in_tensor, offset, numel); + comm_context->Send(partial_tensor, numel, i, stream); + offset += numel; + } + comm_context->Recv(out_tensor, numel, opts.root_rank, stream); + this->GroupEnd(); + } else { + comm_context->Recv(out_tensor, numel, opts.root_rank, stream); + } + }, + in_tensor, + CommType::SCATTER, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Gather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*out_tensor); + + std::vector partial_tensors; + if (rank_ == opts.root_rank) { + partial_tensors.reserve(size_); + size_t offset = 0; + size_t numel = out_tensor->numel() / size_; + for (auto i = 0; i < size_; i++) { + partial_tensors.push_back(GetPartialTensor(*out_tensor, + static_cast(offset), + static_cast(numel))); + offset += numel; + } + } + return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Gather( + std::vector* gather_tensors_ptr, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(in_tensor); + CheckTensorContiguous(*gather_tensors_ptr); + + auto& gather_tensors = *gather_tensors_ptr; + PADDLE_ENFORCE_GT(size_, + opts.root_rank, + common::errors::InvalidArgument( + "root world size [%d] is less than root rank [%d]", + size_, + opts.root_rank)); + auto gather_func = [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream) { + VLOG(3) << "[Gather] " + << "sendbuff: " << in_tensor.data() + << ", count: " << in_tensor.numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(in_tensor.dtype())) + << ", root: " << opts.root_rank + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", 
stream: " << stream << ", rank_in_group: " << rank_ + << ", nranks: " << size_ << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << ", " << GetGroupMessage(); + + this->GroupStart(); + // root receive from all devices + if (rank_ == opts.root_rank) { + for (auto i = 0; i < size_; i++) { + auto& gather_tensor = gather_tensors[i]; + comm_context->Recv(&gather_tensor, gather_tensor.numel(), i, stream); + } + } + // send to root + comm_context->Send(in_tensor, in_tensor.numel(), opts.root_rank, stream); + this->GroupEnd(); + }; + return Collective( + gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Recv( + phi::DenseTensor* tensor, + int src_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(*tensor); + // numel > 0 indicates the tensor need to be sliced + phi::DenseTensor partial_tensor; + if (numel > 0) { + partial_tensor = GetPartialTensor(*tensor, offset, numel); + tensor = &partial_tensor; + } + + return Point2Point( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream, + int rank_in_group) { + VLOG(3) << "[flagcxRecv] " + << "recvbuff: " << tensor->data() + << ", count: " << tensor->numel() << ", datatype: " + << FlagcxDTypeToString(phi::ToFlagcxDataType(tensor->dtype())) + << ", src_in_group: " << src_rank + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + + comm_context->Recv(tensor, tensor->numel(), rank_in_group, stream); + }, + src_rank, + *tensor, + CommType::RECV, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Send( + const phi::DenseTensor& tensor, + int dst_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(tensor); + // numel > 0 indicates the tensor need to be sliced + const phi::DenseTensor& tensor_maybe_partial = + numel > 0 ? 
GetPartialTensor(tensor, offset, numel) : tensor; + + return Point2Point( + [&](phi::distributed::FlagcxCommContext* comm_context, + flagcxStream_t stream, + int rank_in_group) { + VLOG(3) << "[flagcxSend] " + << "sendbuff: " << tensor_maybe_partial.data() + << ", count: " << tensor_maybe_partial.numel() << ", datatype: " + << FlagcxDTypeToString( + phi::ToFlagcxDataType(tensor_maybe_partial.dtype())) + << ", dst_in_group: " << dst_rank + << ", flagcxcomm: " << comm_context->GetFlagcxComm() + << ", stream: " << stream + << ", rank_in_group: " << rank_in_group << ", nranks: " << size_ + << ", offset: " << offset << ", sync_op: " << sync_op + << ", use_calc_stream: " << use_calc_stream << ", " + << GetGroupMessage(); + + comm_context->Send(tensor_maybe_partial, + tensor_maybe_partial.numel(), + rank_in_group, + stream); + }, + dst_rank, + tensor_maybe_partial, + CommType::SEND, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::CreateTask( + const Place& place, + int rank, + CommType comm_type, + bool is_sync, + bool use_calc_stream, + int gid) { + return std::make_shared( + place, rank, comm_type, is_sync, use_calc_stream, gid); +} + +void ProcessGroupFlagcx::GetStoreKey(const std::string& place_key, + CommType comm_type, + std::string* store_key) { + *store_key = "flagcx_ids/" + std::to_string(gid_) + "/0"; + + place_to_group_key_[place_key] = *store_key; +} + +void ProcessGroupFlagcx::CreateFlagcxEnvCache(const Place& place, + const std::string& place_key, + const std::string& store_key, + CommType comm_type, + int p2p_rank) { + // TODO(changtao): we only support one flagcx comm ctx + if (flagcx_comm_ != nullptr) { + return; + } + VLOG(3) << "init flagcx rank_in_group: " << rank_ << ", nranks: " << size_ + << ", gid: " << gid_ << ", place key: " << place_key + << ", store_key: " << store_key; + store_key_ = store_key; + + phi::distributed::CommContextManager::CreateFlagcxCommContext( + store_, store_key, rank_, size_, ""); + + auto flagcx_comm_ctx = this->GetCommContext(&store_key); + VLOG(3) << "Get flagcx comm: " << flagcx_comm_ctx->GetFlagcxComm(); + flagcx_comm_ = flagcx_comm_ctx->GetFlagcxComm(); + auto comm_ctx = std::make_unique(place); + + auto* calc_ctx = static_cast( + phi::DeviceContextPool::Instance().Get(place)); + + place_to_calc_event_.emplace( + place_key, + platform::DeviceEvent(place, platform::GenerateDeviceEventFlag())); + place_to_calc_ctx_.emplace(place_key, calc_ctx); + place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx)); +} + +void ProcessGroupFlagcx::SyncCalcStream(const Place& place, + const std::string& place_key) { + auto& calc_event = place_to_calc_event_.at(place_key); + const auto* calc_ctx = place_to_calc_ctx_.at(place_key); + const auto* comm_ctx = place_to_comm_ctx_.at(place_key).get(); + calc_event.Record(calc_ctx); + calc_event.Wait(platform::Place2DeviceType(place), comm_ctx); +} + +void ProcessGroupFlagcx::EagerConnect() { + const auto deviceId = phi::backends::gpu::GetCurrentDeviceId(); + const auto& place = phi::GPUPlace(deviceId); + const auto key = GetKeyFromPlace(place); + + platform::CUDADeviceGuard cuda_guard(place); + std::string store_key; + GetStoreKey(key, CommType::ALLREDUCE, &store_key); + + auto it = place_to_comm_ctx_.find(key); + if (it == place_to_comm_ctx_.end()) { + CreateFlagcxEnvCache(place, key, store_key, CommType::ALLREDUCE); + } +} + +void ProcessGroupFlagcx::EagerConnectRingExchange() { + std::vector> peers; + const auto& place = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId()); + + for 
(int rank = 0; rank < size_; rank++) { + auto peer_rank = rank + 1 >= size_ ? 0 : rank + 1; + peers.push_back(std::make_pair(rank, peer_rank)); + } + + for (auto& peer : peers) { + int f_rank = peer.first; + int s_rank = peer.second; + + int peer_rank = 0; + int cur_rank = rank_; + if (rank_ == f_rank) { + peer_rank = s_rank; + } else if (rank_ == s_rank) { + peer_rank = f_rank; + } else { + continue; + } + + int low_rank = cur_rank < peer_rank ? cur_rank : peer_rank; + int high_rank = cur_rank < peer_rank ? peer_rank : cur_rank; + std::string key = + std::to_string(low_rank) + "->" + std::to_string(high_rank); + + auto p2p_rank = rank_ < peer_rank ? 0 : 1; + platform::CUDADeviceGuard cuda_guard(place); + std::string store_key; + GetStoreKey(key, CommType::SEND, &store_key); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateFlagcxEnvCache(place, key, store_key, CommType::SEND, p2p_rank); + } + } +} + +std::shared_ptr ProcessGroupFlagcx::Collective( + std::function + fn, + const std::vector& tensors, + CommType comm_type, + bool sync_op, + bool use_calc_stream) { + CheckTensorContiguous(tensors); + + VLOG(3) << "flagcx debug: collective start"; + comm_seq_++; + PADDLE_ENFORCE_GT( + tensors.size(), + 0, + common::errors::InvalidArgument("Num of tensors must be greater than 0")); + const auto& place = tensors[0].place(); + const auto& key = GetKeyFromPlace(place); + + platform::CUDADeviceGuard cuda_guard(place); + + std::string store_key; + GetStoreKey(key, comm_type, &store_key); + + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateFlagcxEnvCache(place, key, store_key, comm_type); + } + + if (!use_calc_stream) { + SyncCalcStream(place, key); + } + + auto task = + CreateTask(place, rank_, comm_type, sync_op, use_calc_stream, gid_); + + const auto& comm_ctx = place_to_comm_ctx_.at(key); + const auto* calc_ctx = place_to_calc_ctx_.at(key); + + auto flagcx_comm_ctx = this->GetCommContext(&store_key); + + flagcxStream_t flagcx_stream; + if (use_calc_stream) { + auto calc_stream = calc_ctx->stream(); + flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy( + &flagcx_stream, reinterpret_cast(&calc_stream)); + } else { + auto comm_stream = comm_ctx->stream(); + flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy( + &flagcx_stream, reinterpret_cast(&comm_stream)); + } + + if (!FLAGS_enable_async_trace) { + fn(flagcx_comm_ctx, flagcx_stream); + } + + if (!use_calc_stream) { + if (!is_coalescing_) { + task->UpdateWaitChain(*comm_ctx); + for (size_t i = 0; i < tensors.size(); ++i) { + allocation_stream_pairs_.emplace_back( + tensors[i].Holder(), + *reinterpret_cast(flagcx_stream)); + } + } else { + for (size_t i = 0; i < tensors.size(); ++i) { + coalescing_tensors_.emplace_back( + std::make_shared(tensors[i])); + } + coalescing_place_keys_.push_back(key); + } + } + + if (sync_op) { + task->Wait(); + } + + flagcx_comm_ctx->flagcx_handler_->devHandle->streamFree(flagcx_stream); + + return task; +} + +std::shared_ptr ProcessGroupFlagcx::Collective( + std::function + fn, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream) { + const std::vector tensors = {tensor}; + return Collective(fn, tensors, comm_type, sync_op, use_calc_stream); +} + +std::shared_ptr ProcessGroupFlagcx::Point2Point( + std::function< + void(phi::distributed::FlagcxCommContext*, flagcxStream_t, int)> fn, + int peer, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream) { + 
CheckTensorContiguous(tensor); + + const auto& place = tensor.place(); + + int p2p_rank = 0; + int p2p_target_rank = 0; + bool is_batch_p2p = s_group_call_counter > 0; + std::string key = ""; + + if (is_batch_p2p) { + key = GetKeyFromPlace(place); + p2p_rank = rank_; + p2p_target_rank = peer; + } else { + int low_rank = rank_ < peer ? rank_ : peer; + int high_rank = rank_ < peer ? peer : rank_; + key = std::to_string(low_rank) + "->" + std::to_string(high_rank); + p2p_rank = rank_ < peer ? 0 : 1; + p2p_target_rank = 1 - p2p_rank; + } + + platform::CUDADeviceGuard cuda_guard(place); + + std::string store_key; + GetStoreKey(key, comm_type, &store_key); + + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateFlagcxEnvCache(place, key, store_key, comm_type, p2p_rank); + } + if (p2p_comm_seq_.find(key) == p2p_comm_seq_.end()) { + p2p_comm_seq_[key] = 0; + } + p2p_comm_seq_[key]++; + + if (!use_calc_stream) { + SyncCalcStream(place, key); + } + + auto task = + CreateTask(place, rank_, comm_type, sync_op, use_calc_stream, gid_); + const auto* calc_ctx = place_to_calc_ctx_.at(key); + const auto& comm_ctx = place_to_comm_ctx_.at(key); + + auto flagcx_comm_ctx = this->GetCommContext(&store_key); + + flagcxStream_t flagcx_stream; + if (use_calc_stream) { + auto calc_stream = calc_ctx->stream(); + flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy( + &flagcx_stream, reinterpret_cast(&calc_stream)); + } else { + auto comm_stream = comm_ctx->stream(); + flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy( + &flagcx_stream, reinterpret_cast(&comm_stream)); + } + + if (!FLAGS_enable_async_trace) { + fn(flagcx_comm_ctx, flagcx_stream, p2p_target_rank); + } + + if (!use_calc_stream) { + if (!is_coalescing_) { + task->UpdateWaitChain(*comm_ctx); + allocation_stream_pairs_.emplace_back( + tensor.Holder(), *reinterpret_cast(flagcx_stream)); + } else { + coalescing_tensors_.emplace_back( + std::make_shared(tensor)); + coalescing_place_keys_.push_back(key); + } + } + + if (sync_op) { + task->Wait(); + } + + flagcx_comm_ctx->flagcx_handler_->devHandle->streamFree(flagcx_stream); + return task; +} + +std::shared_ptr +ProcessGroupFlagcx::CreateProcessGroupFlagcx( + const std::shared_ptr& store, + int rank, + int size, + int gid, + int64_t timeout, + int flagcx_comm_init_option) { + auto process_group = std::make_shared( + store, rank, size, gid, timeout, flagcx_comm_init_option); + ProcessGroupIdMap::GetInstance().emplace(gid, process_group); + return process_group; +} + +phi::distributed::FlagcxCommContext* ProcessGroupFlagcx::GetCommContext( + const std::string* key) { + std::string store_key = std::to_string(this->gid_); + if (key && !key->empty()) { + store_key = *key; + } + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + auto comm_context = static_cast( + comm_context_manager.Get(store_key)); + PADDLE_ENFORCE_NE( + comm_context, + nullptr, + common::errors::Unavailable("FlagcxCommContext is nullptr")); + return comm_context; +} + +void ProcessGroupFlagcx::StartCoalescing() { + PADDLE_ENFORCE_EQ(is_coalescing_, + false, + common::errors::PreconditionNotMet( + "Coalescing is on, please call EndCoalesce.")); + is_coalescing_ = true; + this->GroupStart(); +} + +void ProcessGroupFlagcx::EndCoalescing( + std::optional>> tasks_opt) { + this->GroupEnd(); + + // NOTE(shenliang03): If using calculate stream, no need to record stream and + // update task. 
+ if (!tasks_opt.has_value() || coalescing_tensors_.empty()) { + is_coalescing_ = false; + return; + } + + auto& tasks = tasks_opt.value(); + + PADDLE_ENFORCE_EQ( + tasks.size(), + coalescing_tensors_.size(), + common::errors::PreconditionNotMet( + "Number of tasks[%d] do not match number of collectives[%d].", + tasks.size(), + coalescing_tensors_.size())); + + for (size_t i = 0; i < tasks.size(); ++i) { + auto* flagcx_task = + static_cast(tasks[i].get()); + const auto& tensor = coalescing_tensors_[i]; + const auto& key = coalescing_place_keys_[i]; + const auto& comm_ctx = place_to_comm_ctx_.at(key); + auto flagcx_comm_ctx = this->GetCommContext(&store_key_); + auto comm_stream = comm_ctx->stream(); + flagcxStream_t flagcx_stream; + flagcx_comm_ctx->flagcx_handler_->devHandle->streamCopy( + &flagcx_stream, reinterpret_cast(&comm_stream)); + + flagcx_task->UpdateWaitChain(*comm_ctx); + allocation_stream_pairs_.emplace_back( + tensor->Holder(), *reinterpret_cast(flagcx_stream)); + flagcx_comm_ctx->flagcx_handler_->devHandle->streamFree(flagcx_stream); + } + + is_coalescing_ = false; + coalescing_tensors_.clear(); + coalescing_place_keys_.clear(); +} +} // namespace paddle::distributed diff --git a/paddle/fluid/distributed/collective/process_group_flagcx.h b/paddle/fluid/distributed/collective/process_group_flagcx.h new file mode 100644 index 0000000000000..96ae9dd09391b --- /dev/null +++ b/paddle/fluid/distributed/collective/process_group_flagcx.h @@ -0,0 +1,302 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_with_stream.h" +#include "paddle/phi/backends/gpu/forwards.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/device_context.h" +#include "paddle/phi/core/distributed/flagcx_comm_context.h" +#include "paddle/phi/core/distributed/store/store.h" +#include "paddle/phi/core/platform/device_event.h" + +namespace paddle { +namespace distributed { + +using Place = phi::Place; + +class ProcessGroupFlagcx final : public ProcessGroupWithStream { + public: + class FlagcxTask final : public ProcessGroupWithStream::TaskStream, + public std::enable_shared_from_this { + public: + FlagcxTask(const Place& place, + int rank, + CommType comm_type, + bool sync_op, + bool use_calc_stream, + int gid); + virtual ~FlagcxTask(); + + bool IsCompleted() override; + bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) override; + void Synchronize() override; + void UpdateWaitChain(const phi::DeviceContext& ctx) override; + + bool IsBlockCPUInWait() const { return block_cpu_in_wait_; } + void SetBlockCPUInWait() { block_cpu_in_wait_ = true; } + + // TODO(changtao): methods below will be removed later + FlagcxTask(const std::vector& places, + int rank, + CommType CommType, + const std::vector& inputs); + + void RemoveHolderStreamInGroup(); + + private: + bool block_cpu_in_wait_{false}; + std::shared_ptr comm_event_; // event on comm stream + Place task_place_; + int gid_; + }; + + public: + static std::shared_ptr CreateProcessGroupFlagcx( + const std::shared_ptr& store, + int rank, + int size, + int gid, + int64_t timeout, + int flagcx_comm_init_option); + + ProcessGroupFlagcx(const std::shared_ptr& store, + int rank, + int size, + int gid, + int64_t timeout = 30 * 60 * 1000, + int flagcx_comm_init_option = 0); + ~ProcessGroupFlagcx(); + + std::string GetBackendName() const override { return "FLAGCX"; } + + phi::DeviceContext* GetDeviceContext(const Place& place) const override; + + phi::DeviceContext* GetDeviceContext(const Place& place, + bool use_calc_stream) const override; + + std::shared_ptr AllGather( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr AllReduce( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const AllreduceOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr AllToAll( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& out_size_each_rank, + const std::vector& in_size_each_rank, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr AllToAll( + std::vector* out_tensors, + const std::vector& in_tensors, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Barrier( + const BarrierOptions& = BarrierOptions()) override; + + std::shared_ptr Broadcast( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr ReduceScatter( + phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ReduceScatterOptions& opts, + bool sync_op, + bool 
use_calc_stream) override; + + std::shared_ptr Scatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Gather(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Gather( + std::vector* gather_tensors_ptr, + const phi::DenseTensor& in_tensor, + const GatherOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Recv(phi::DenseTensor* tensor, + int src_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr Send(const phi::DenseTensor& tensor, + int dst_rank, + int64_t offset, + int64_t numel, + bool sync_op, + bool use_calc_stream) override; + + // Can't declare these two functions as static because we access non-static + // variable in these functions + void GroupStart(); + + void GroupEnd(); + + flagcxComm_t FlagcxComm(const Place& place) const; + + const bool GetFlagcxCommInitOption() { return flagcx_comm_init_option_; } + + phi::distributed::FlagcxCommContext* GetOrCreateCommContext( + const Place& place, CommType comm_type = CommType::UNKNOWN); + + private: + std::shared_ptr CreateTask( + const Place& place, + int rank, + CommType op_type, + bool sync_op, + bool use_calc_stream, + int gid); + + void GetStoreKey(const std::string& place_key, + CommType comm_type, + std::string* store_key); + + void CreateFlagcxEnvCache(const Place& place, + const std::string& place_key, + const std::string& store_key, + CommType comm_type, + int p2p_rank = 0); + + void SyncCalcStream(const Place& place, const std::string& place_key); + + std::shared_ptr Collective( + std::function + fn, + const std::vector& tensors, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Collective( + std::function + fn, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Point2Point( + std::function< + void(phi::distributed::FlagcxCommContext*, flagcxStream_t, int)> fn, + int peer, + const phi::DenseTensor& tensor, + CommType comm_type, + bool sync_op, + bool use_calc_stream); + + phi::distributed::FlagcxCommContext* GetCommContext( + const std::string* key = nullptr); + + void EraseTensorHolders() { + for (const auto& allocation_stream : allocation_stream_pairs_) { + auto holder_ptr = allocation_stream.first.lock(); + if (holder_ptr) { + auto stream = reinterpret_cast(allocation_stream.second); + memory::EraseStream(holder_ptr, *stream); + } + } + VLOG(5) << "After task wait/synchronize, total " + << allocation_stream_pairs_.size() + << " tensor(s) allocation stream have been removed."; + allocation_stream_pairs_.clear(); + } + + virtual void StartCoalescing(); + + virtual void EndCoalescing( + std::optional>> + tasks_opt = std::nullopt); + + void EagerConnect(); + + void EagerConnectRingExchange(); + + private: + std::shared_ptr store_; + + std::unordered_map + place_to_calc_event_; // event on calc stream + // TODO(changtao02): find a way to manage different context + std::unordered_map place_to_calc_ctx_; + std::unordered_map> + place_to_comm_ctx_; + + uint64_t comm_seq_{0}; + std::unordered_map p2p_comm_seq_; + std::unordered_map place_to_group_key_; + + // TODO(changtao): attrs below will be removed later + std::mutex mutex_; + static uint64_t s_group_call_counter; + // default 30 minutes + int64_t pg_timeout_; 
+ int flagcx_comm_init_option_; + + // optimize memory for process_group + std::vector, gpuStream_t>> + allocation_stream_pairs_; + flagcxComm_t flagcx_comm_{nullptr}; + std::string store_key_; + + // For coalescing tensors processing (eg. batch_isend_irecv) + bool is_coalescing_{false}; + std::vector> coalescing_tensors_; + std::vector coalescing_place_keys_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 48c2c3e7b6c39..e4b131413eb6d 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -189,6 +189,9 @@ if(WITH_PYTHON) if(WITH_MPI) set(PYBIND_DEPS ${PYBIND_DEPS} process_group_mpi) endif() + if(WITH_FLAGCX) + set(PYBIND_DEPS ${PYBIND_DEPS} process_group_flagcx) + endif() if(WITH_CUSTOM_DEVICE) set(PYBIND_DEPS ${PYBIND_DEPS} process_group_custom) endif() diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index b0fb8624b6afa..20d16307bb65e 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -54,6 +54,10 @@ limitations under the License. */ #include "paddle/fluid/distributed/collective/xpu_async_load.h" #endif +#if defined(PADDLE_WITH_FLAGCX) +#include "paddle/fluid/distributed/collective/process_group_flagcx.h" +#endif + #include "paddle/phi/kernels/sync_batch_norm_kernel.h" namespace paddle::pybind { @@ -82,6 +86,10 @@ using GlooStore = paddle::distributed::ProcessGroupGloo::GlooStore; using GlooOptions = paddle::distributed::ProcessGroupGloo::GlooOptions; #endif +#if defined(PADDLE_WITH_FLAGCX) +using ProcessGroupFlagcx = paddle::distributed::ProcessGroupFlagcx; +#endif + static UNUSED void *use_ccl_comm_func = phi::detail::GetCCLComm(phi::CPUPlace()); @@ -1505,6 +1513,20 @@ void BindDistributed(py::module *m) { py::call_guard()); #endif +#if defined(PADDLE_WITH_FLAGCX) + py::class_>( + *m, "ProcessGroupFlagcx", ProcessGroup) + .def_static("create", + distributed::ProcessGroupFlagcx::CreateProcessGroupFlagcx, + py::arg("store"), + py::arg("rank"), + py::arg("world_size"), + py::arg("group_id") = 0, + py::arg("timeout") = 30 * 60 * 1000, + py::arg("nccl_comm_init_option") = 0, + py::call_guard()); +#endif + m->def( "eager_assign_group_by_size", [](py::handle py_tensors, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 0df2e34e880af..9c9def757d95d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -308,6 +308,14 @@ bool IsCompiledWithNCCL() { #endif } +bool IsCompiledWithFlagcx() { +#ifdef PADDLE_WITH_FLAGCX + return true; +#else + return false; +#endif +} + bool IsCompiledWithMPI() { #ifdef PADDLE_WITH_MPI return true; @@ -2553,6 +2561,7 @@ All parameter, weight, gradient are variables in Paddle. 
std::make_unique(); VLOG(4) << "Initialize tensor operants successfully"; }); + m.def("is_compiled_with_flagcx", IsCompiledWithFlagcx); m.def("is_compiled_with_avx", IsCompiledWithAVX); m.def("is_compiled_with_cuda", IsCompiledWithCUDA); m.def("is_compiled_with_cudnn_frontend", IsCompiledWithCudnnFrontend); diff --git a/paddle/phi/CMakeLists.txt b/paddle/phi/CMakeLists.txt index 3bcc3539abeac..1f16ea88da37b 100644 --- a/paddle/phi/CMakeLists.txt +++ b/paddle/phi/CMakeLists.txt @@ -87,6 +87,10 @@ if(WITH_GLOO) list(APPEND PHI_DEPS gloo) endif() +if(WITH_FLAGCX) + list(APPEND PHI_DEPS flagcx) +endif() + if(WITH_CUDNN_FRONTEND) list(APPEND PHI_DEPS cudnn-frontend) endif() diff --git a/paddle/phi/backends/dynload/CMakeLists.txt b/paddle/phi/backends/dynload/CMakeLists.txt index 0f71d1741ec1d..605bb21cee991 100644 --- a/paddle/phi/backends/dynload/CMakeLists.txt +++ b/paddle/phi/backends/dynload/CMakeLists.txt @@ -80,6 +80,10 @@ if(WITH_XPU) collect_srcs(backends_srcs SRCS xpti.cc) endif() +if(WITH_FLAGCX) + collect_srcs(backends_srcs SRCS flagcx.cc) +endif() + if(WITH_FLASHATTN) list(APPEND DYNLOAD_COMMON_SRCS flashattn.cc) endif() diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc index fd1823fce5ada..b0340f7b46c35 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc @@ -68,6 +68,11 @@ PHI_DEFINE_string(rccl_dir, "dlopen will search rccl from LD_LIBRARY_PATH"); #endif +// use hardcoded path for now to ensure correctness +#ifdef PADDLE_WITH_FLAGCX +COMMON_DECLARE_string(flagcx_dir); +#endif + #ifdef PADDLE_WITH_XPU PD_DEFINE_string(xpti_dir, "", "Specify path for loading libxpti.so."); #endif @@ -777,6 +782,14 @@ void* GetNCCLDsoHandle() { #endif } +void* GetFLAGCXDsoHandle() { +#ifdef PADDLE_WITH_FLAGCX + return GetDsoHandleFromSearchPath(FLAGS_flagcx_dir, "libflagcx.so"); +#else + return nullptr; +#endif +} + void* GetTensorRtDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_tensorrt_dir, "libnvinfer.dylib"); diff --git a/paddle/phi/backends/dynload/dynamic_loader.h b/paddle/phi/backends/dynload/dynamic_loader.h index 68e23828265b0..10e286aaa64b4 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.h +++ b/paddle/phi/backends/dynload/dynamic_loader.h @@ -39,6 +39,7 @@ void* GetWarpRNNTDsoHandle(); void* GetFlashAttnDsoHandle(); void* GetFlashAttnV3DsoHandle(); void* GetNCCLDsoHandle(); +void* GetFLAGCXDsoHandle(); void* GetTensorRtDsoHandle(); void* GetMKLMLDsoHandle(); void* GetLAPACKDsoHandle(); diff --git a/paddle/phi/backends/dynload/flagcx.cc b/paddle/phi/backends/dynload/flagcx.cc new file mode 100644 index 0000000000000..71c2f7937511b --- /dev/null +++ b/paddle/phi/backends/dynload/flagcx.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/backends/dynload/flagcx.h" + +namespace phi { +namespace dynload { + +std::once_flag flagcx_dso_flag; +void* flagcx_dso_handle = nullptr; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +FLAGCX_RAND_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace phi diff --git a/paddle/phi/backends/dynload/flagcx.h b/paddle/phi/backends/dynload/flagcx.h new file mode 100644 index 0000000000000..55224fcaa3938 --- /dev/null +++ b/paddle/phi/backends/dynload/flagcx.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include + +#include // NOLINT + +#include "paddle/phi/backends/dynload/dynamic_loader.h" +#include "paddle/phi/common/port.h" + +namespace phi { +namespace dynload { + +extern std::once_flag flagcx_dso_flag; +extern void* flagcx_dso_handle; + +#define DECLARE_DYNAMIC_LOAD_FLAGCX_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ + using flagcx_func = decltype(&::__name); \ + std::call_once(flagcx_dso_flag, []() { \ + flagcx_dso_handle = phi::dynload::GetFLAGCXDsoHandle(); \ + }); \ + static void* p_##__name = dlsym(flagcx_dso_handle, #__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern struct DynLoad__##__name __name + +#define FLAGCX_RAND_ROUTINE_EACH(__macro) \ + __macro(flagcxGetUniqueId); \ + __macro(flagcxCommInitRank); \ + __macro(flagcxGetVersion); \ + __macro(flagcxCommAbort); \ + __macro(flagcxCommDestroy); \ + __macro(flagcxCommCount); \ + __macro(flagcxCommUserRank); \ + __macro(flagcxAllReduce); \ + __macro(flagcxBroadcast); \ + __macro(flagcxAllGather); \ + __macro(flagcxGroupStart); \ + __macro(flagcxGroupEnd); \ + __macro(flagcxReduce); \ + __macro(flagcxReduceScatter); \ + __macro(flagcxCommGetAsyncError); \ + __macro(flagcxSend); \ + __macro(flagcxRecv); \ + __macro(flagcxHandleInit); \ + __macro(flagcxHandleFree); \ + __macro(flagcxGetErrorString); + +FLAGCX_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLAGCX_WRAP) + +#undef DECLARE_DYNAMIC_LOAD_FLAGCX_WRAP + +} // namespace dynload +} // namespace phi diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt index 62a8034ebd173..9deea1b7414ae 100644 --- a/paddle/phi/core/distributed/CMakeLists.txt +++ b/paddle/phi/core/distributed/CMakeLists.txt @@ -23,4 +23,8 @@ if(WITH_XPU_BKCL) list(APPEND DISTRIBUTED_COMMON_SRCS bkcl_comm_context.cc) endif() +if(WITH_FLAGCX) + list(APPEND DISTRIBUTED_COMMON_SRCS flagcx_comm_context.cc flagcx_tools.cc) +endif() + collect_srcs(core_srcs SRCS ${DISTRIBUTED_COMMON_SRCS}) diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 4b3ebadfa4395..deeb17b96874f 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -44,6 +44,11 @@ #include "paddle/phi/core/distributed/xccl_comm_context.h" 
#endif +#if defined(PADDLE_WITH_FLAGCX) +#include "paddle/phi/core/distributed/flagcx_comm_context.h" +#include "paddle/phi/core/distributed/flagcx_tools.h" +#endif + namespace phi::distributed { int CommContextManager::device_id = -1; @@ -261,6 +266,51 @@ void CommContextManager::CreateBKCLCommContext( comm_context_manager.Emplace(unique_comm_key, std::move(bkcl_comm_context)); } #endif + +#if defined(PADDLE_WITH_FLAGCX) +void CommContextManager::CreateFlagcxCommContext( + const std::shared_ptr<Store>& store, + const std::string& unique_comm_key, + int rank, + int size, + const std::string& hash_key) { + auto& comm_context_manager = CommContextManager::GetInstance(); + if (comm_context_manager.Has(unique_comm_key)) { + return; + } + flagcxHandlerGroup_t flagcx_handler; + phi::dynload::flagcxHandleInit(&flagcx_handler); + if (rank == 0) { + phi::dynload::flagcxGetUniqueId(&flagcx_handler->uniqueId); + } + + std::string unique_key = "FlagcxCommContext/" + unique_comm_key + hash_key; + if (rank == 0) { + std::vector<uint8_t> flagcx_id_wrapper( + reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId), + reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId) + + sizeof(flagcxUniqueId)); + store->set(unique_key, flagcx_id_wrapper); + } else { + const auto& flagcx_id_wrapper = store->get(unique_key); + std::memcpy(reinterpret_cast<uint8_t*>(flagcx_handler->uniqueId), + flagcx_id_wrapper.data(), + flagcx_id_wrapper.size()); + } + + VLOG(3) << "init FlagcxCommContext rank: " << rank << ", size: " << size + << ", unique_comm_key: " << unique_comm_key + << ", unique_key: " << unique_key << ", flagcx_id: " + << SerializeFlagcxUniqueId(*flagcx_handler->uniqueId); + auto flagcx_comm_context = + std::make_unique<FlagcxCommContext>(rank, size, flagcx_handler); + // TODO(changtao): find a way to manage different device context, + // now we use cuda device context as default + comm_context_manager.SetStore(store); + comm_context_manager.Emplace(unique_comm_key, std::move(flagcx_comm_context)); +} +#endif + CommContext* CommContextManager::Emplace( const std::string& unique_comm_key, std::unique_ptr<CommContext> comm_context) { diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 9e0cb8e5ec3d7..8b08508262806 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -28,6 +28,10 @@ #include "paddle/phi/backends/gpu/forwards.h" #endif +#if defined(PADDLE_WITH_FLAGCX) +#include <flagcx.h> +#endif + namespace phi { namespace distributed { @@ -105,6 +109,14 @@ class CommContextManager { const std::string& hash_key = ""); #endif +#if defined(PADDLE_WITH_FLAGCX) + static void CreateFlagcxCommContext(const std::shared_ptr<Store>& store, + const std::string& unique_comm_key, + int rank, + int size, + const std::string& hash_key = ""); +#endif + private: DISABLE_COPY_AND_ASSIGN(CommContextManager); diff --git a/paddle/phi/core/distributed/flagcx_comm_context.cc b/paddle/phi/core/distributed/flagcx_comm_context.cc new file mode 100644 index 0000000000000..6aca7ec98f252 --- /dev/null +++ b/paddle/phi/core/distributed/flagcx_comm_context.cc @@ -0,0 +1,178 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/flagcx_comm_context.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/dense_tensor.h" +// #include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" +#include "paddle/phi/core/distributed/check/static_check.h" +#include "paddle/phi/core/distributed/flagcx_tools.h" +#include "paddle/phi/core/distributed/utils.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi::distributed { + +// set this flag to `true` and recompile to enable dynamic checks +// constexpr bool FLAGS_enable_flagcx_dynamic_check = false; + +FlagcxCommContext::FlagcxCommContext(int rank, + int size, + flagcxHandlerGroup_t flagcx_handler) + : CommContext(rank, size), + flagcx_version_(0), + flagcx_handler_(flagcx_handler) { + phi::dynload::flagcxCommInitRank( + &flagcx_handler_->comm, size_, flagcx_handler_->uniqueId, rank_), + phi::dynload::flagcxGetVersion(&flagcx_version_); +} + +int FlagcxCommContext::GetFlagcxVersion() { return flagcx_version_; } + +flagcxComm_t FlagcxCommContext::GetFlagcxComm() { + return flagcx_handler_->comm; +} + +void FlagcxCommContext::Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root, + flagcxStream_t stream) { + CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); + FLAGCX_CHECK(phi::dynload::flagcxBroadcast(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToFlagcxDataType(in_tensor.type()), + root, + flagcx_handler_->comm, + stream)); +} + +void FlagcxCommContext::AllGather(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxStream_t stream) { + phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); + FLAGCX_CHECK(phi::dynload::flagcxAllGather(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToFlagcxDataType(in_tensor.type()), + flagcx_handler_->comm, + stream)); +} +void FlagcxCommContext::ReduceScatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxRedOp_t reduce_type, + flagcxStream_t stream) { + phi::distributed::CommStaticCheck::ScatterLikeShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); + FLAGCX_CHECK( + phi::dynload::flagcxReduceScatter(in_tensor.data(), + out_tensor->data(), + out_tensor->numel(), + ToFlagcxDataType(in_tensor.type()), + reduce_type, + flagcx_handler_->comm, + stream)); +} + +void FlagcxCommContext::Send(const phi::DenseTensor& in_tensor, + const int64_t& count, + const int& peer, + flagcxStream_t stream) { + phi::distributed::CommStaticCheck::CheckShape(in_tensor, rank_, size_); + + FLAGCX_CHECK(phi::dynload::flagcxSend(in_tensor.data(), + count, + ToFlagcxDataType(in_tensor.dtype()), + peer, + flagcx_handler_->comm, + stream)); + VLOG(3) << "rank " << GetRank() << " send " + << common::product(in_tensor.dims()) << " to " << peer; +} + +void FlagcxCommContext::Recv(phi::DenseTensor* out_tensor, + const int64_t& count, + const int& peer, + flagcxStream_t stream) { + 
phi::distributed::CommStaticCheck::CheckShape(*out_tensor, rank_, size_); + + FLAGCX_CHECK(phi::dynload::flagcxRecv(out_tensor->data(), + count, + ToFlagcxDataType(out_tensor->dtype()), + peer, + flagcx_handler_->comm, + stream)); + VLOG(3) << "rank " << GetRank() << " recv " + << common::product(out_tensor->dims()) << " from " << peer; +} + +void FlagcxCommContext::AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxRedOp_t reduce_type, + flagcxStream_t stream) { + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); + FLAGCX_CHECK(phi::dynload::flagcxAllReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToFlagcxDataType(in_tensor.type()), + reduce_type, + flagcx_handler_->comm, + stream)); +} + +void FlagcxCommContext::Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxRedOp_t reduce_type, + int root, + flagcxStream_t stream) { + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ root, + /*cur_rank*/ rank_, + size_); + FLAGCX_CHECK(phi::dynload::flagcxReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToFlagcxDataType(in_tensor.type()), + reduce_type, + root, + flagcx_handler_->comm, + stream)); +} + +void FlagcxCommContext::GroupStart() { + FLAGCX_CHECK(phi::dynload::flagcxGroupStart(flagcx_handler_->comm)); +} +void FlagcxCommContext::GroupEnd() { + FLAGCX_CHECK(phi::dynload::flagcxGroupEnd(flagcx_handler_->comm)); +} + +} // namespace phi::distributed diff --git a/paddle/phi/core/distributed/flagcx_comm_context.h b/paddle/phi/core/distributed/flagcx_comm_context.h new file mode 100644 index 0000000000000..9453788d971b1 --- /dev/null +++ b/paddle/phi/core/distributed/flagcx_comm_context.h @@ -0,0 +1,92 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "paddle/common/macros.h" +#include "paddle/phi/backends/dynload/flagcx.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_decls.h" +#include "paddle/phi/core/distributed/comm_context.h" + +namespace phi { +class DenseTensor; +namespace distributed { + +class FlagcxCommContext final : public CommContext { + public: + FlagcxCommContext(int rank, int size, flagcxHandlerGroup_t flagcx_handler); + ~FlagcxCommContext() override = default; + + int GetFlagcxVersion(); + + flagcxComm_t GetFlagcxComm(); + + void Broadcast(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + int root, + flagcxStream_t stream); + + void Send(const phi::DenseTensor& in_tensor, + const int64_t& count, + const int& peer, + flagcxStream_t stream); + + void Recv(phi::DenseTensor* out_tensor, + const int64_t& count, + const int& peer, + flagcxStream_t stream); + + void ReduceScatter(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxRedOp_t reduce_type, + flagcxStream_t stream); + + void AllGather(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxStream_t stream); + + void AllReduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxRedOp_t reduce_type, + flagcxStream_t stream); + + void Reduce(phi::DenseTensor* out_tensor, + const phi::DenseTensor& in_tensor, + flagcxRedOp_t reduce_type, + int root, + flagcxStream_t stream); + + void GroupStart(); + + void GroupEnd(); + + private: + DISABLE_COPY_AND_ASSIGN(FlagcxCommContext); + + int flagcx_version_; + + std::unique_ptr<phi::GPUContext> dev_ctx_; + + // used for comm wait compute, compute_stream-->event-->comm_stream + std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> compute_event_; + + // used for compute wait comm, comm_stream-->event-->compute_stream + std::shared_ptr<std::remove_pointer<phi::gpuEvent_t>::type> comm_event_; + + public: + flagcxHandlerGroup_t flagcx_handler_; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/flagcx_tools.cc b/paddle/phi/core/distributed/flagcx_tools.cc new file mode 100644 index 0000000000000..5095de2fe3e3f --- /dev/null +++ b/paddle/phi/core/distributed/flagcx_tools.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/flagcx_tools.h" + +#include <unordered_map> + +#include "paddle/common/errors.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { +namespace distributed { + +flagcxRedOp_t ToFlagcxRedType(ReduceOp reduction) { + static const std::unordered_map<ReduceOp, flagcxRedOp_t> red_type = { + {ReduceOp::MIN, flagcxMin}, + {ReduceOp::MAX, flagcxMax}, + {ReduceOp::SUM, flagcxSum}, + {ReduceOp::PRODUCT, flagcxProd}, + {ReduceOp::AVG, flagcxAvg}, + }; + auto it = red_type.find(reduction); + PADDLE_ENFORCE_EQ( + it != red_type.end(), + true, + common::errors::InvalidArgument( + "Invalid flagcx reduction.
Must be flagcxMin | flagcxMax | " + "flagcxProd | flagcxSum | flagcxAvg.")); + return it->second; +} + +std::string SerializeFlagcxUniqueId(const flagcxUniqueId& flagcxID) { + const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&flagcxID); + std::ostringstream oss; + for (auto i = 0; i < FLAGCX_UNIQUE_ID_BYTES; ++i) { + oss << std::hex << static_cast<int>(bytes[i]); + } + return oss.str(); +} + +std::string FlagcxDTypeToString(flagcxDataType_t dtype) { +#define PD_FLAGCX_DTYPE_TO_STR(__flagcx_dtype, __str_dtype) \ + if (dtype == __flagcx_dtype) return __str_dtype; + PD_FLAGCX_DTYPE_TO_STR(flagcxFloat, "float32"); + PD_FLAGCX_DTYPE_TO_STR(flagcxFloat32, "float32"); + PD_FLAGCX_DTYPE_TO_STR(flagcxHalf, "float16"); + PD_FLAGCX_DTYPE_TO_STR(flagcxFloat16, "float16"); + PD_FLAGCX_DTYPE_TO_STR(flagcxBfloat16, "bfloat16"); + PD_FLAGCX_DTYPE_TO_STR(flagcxDouble, "float64"); + PD_FLAGCX_DTYPE_TO_STR(flagcxFloat64, "float64"); + PD_FLAGCX_DTYPE_TO_STR(flagcxInt8, "int8"); + PD_FLAGCX_DTYPE_TO_STR(flagcxChar, "int8"); + PD_FLAGCX_DTYPE_TO_STR(flagcxUint8, "uint8"); + PD_FLAGCX_DTYPE_TO_STR(flagcxInt32, "int32"); + PD_FLAGCX_DTYPE_TO_STR(flagcxInt, "int32"); + PD_FLAGCX_DTYPE_TO_STR(flagcxUint32, "uint32"); + PD_FLAGCX_DTYPE_TO_STR(flagcxInt64, "int64"); + PD_FLAGCX_DTYPE_TO_STR(flagcxUint64, "uint64"); + +#undef PD_FLAGCX_DTYPE_TO_STR + PADDLE_THROW(common::errors::InvalidArgument( + "This datatype %d in flagcx is not supported.", static_cast<int>(dtype))); +} + +std::string FlagcxRedTypeToString(flagcxRedOp_t op) { + if (op == flagcxSum) return "SUM"; + if (op == flagcxProd) return "PROD"; + if (op == flagcxMin) return "MIN"; + if (op == flagcxMax) return "MAX"; + if (op == flagcxAvg) return "AVG"; + return "UDF_" + std::to_string(op); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/flagcx_tools.h b/paddle/phi/core/distributed/flagcx_tools.h new file mode 100644 index 0000000000000..95cbd5ee1529d --- /dev/null +++ b/paddle/phi/core/distributed/flagcx_tools.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include <flagcx.h> +#include <string> +#include "paddle/phi/core/distributed/types.h" + +namespace phi { +namespace distributed { + +#define FLAGCX_CHECK(cmd) \ + do { \ + flagcxResult_t r = cmd; \ + if (r != flagcxSuccess) { \ + PADDLE_THROW( \ + common::errors::External("Failed, FlagCX error %s:%d '%s'\n", \ + __FILE__, \ + __LINE__, \ + phi::dynload::flagcxGetErrorString(r))); \ + } \ + } while (0) + +flagcxRedOp_t ToFlagcxRedType(ReduceOp reduction); + +std::string SerializeFlagcxUniqueId(const flagcxUniqueId& flagcxID); + +std::string FlagcxDTypeToString(flagcxDataType_t dtype); + +std::string FlagcxRedTypeToString(flagcxRedOp_t op); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h index ef8f922ffdede..8d5ac35c22135 100644 --- a/paddle/phi/core/enforce.h +++ b/paddle/phi/core/enforce.h @@ -55,6 +55,10 @@ limitations under the License. */ #endif // __APPLE__ #endif // PADDLE_WITH_CUDA +#ifdef PADDLE_WITH_FLAGCX +#include <flagcx.h> +#endif + #ifdef PADDLE_WITH_HIP #include "paddle/phi/backends/dynload/hipblasLt.h" #include "paddle/phi/backends/dynload/hipfft.h" diff --git a/paddle/phi/core/utils/data_type.h b/paddle/phi/core/utils/data_type.h index 7713af7548527..269da95e7fda1 100644 --- a/paddle/phi/core/utils/data_type.h +++ b/paddle/phi/core/utils/data_type.h @@ -278,5 +278,31 @@ inline BKCLDataType ToBKCLDataType(DataType type) { } } #endif +#if defined(PADDLE_WITH_FLAGCX) +inline flagcxDataType_t ToFlagcxDataType(DataType type) { + if (type == DataType::FLOAT32) { + return flagcxFloat; + } else if (type == DataType::FLOAT64) { + return flagcxDouble; + } else if (type == DataType::INT32) { + return flagcxInt; + } else if (type == DataType::INT64) { + return flagcxInt64; + } else if (type == DataType::FLOAT16) { + return flagcxFloat16; + } else if (type == DataType::UINT8) { + return flagcxUint8; + } else if (type == DataType::INT8) { + return flagcxInt8; + } else if (type == DataType::BOOL) { + return flagcxUint8; + } else if (type == DataType::BFLOAT16) { + return flagcxBfloat16; + } else { + PADDLE_THROW( + errors::Unimplemented("This datatype in flagcx is not supported.")); + } +} +#endif } // namespace phi diff --git a/python/env_dict.py.in b/python/env_dict.py.in index a90679758ce82..5e746b1d7072b 100644 --- a/python/env_dict.py.in +++ b/python/env_dict.py.in @@ -110,4 +110,6 @@ env_dict={ 'WITH_TENSORRT':'@WITH_TENSORRT@', 'TR_INFER_RT':'@TR_INFER_RT@', 'TENSORRT_LIBRARY_DIR':'@TENSORRT_LIBRARY_DIR@', + 'WITH_FLAGCX':'@WITH_FLAGCX@', + 'FLAGCX_LIB':'@FLAGCX_LIB@', } diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index da0090a00bc00..fae9f25e2a8db 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -652,6 +652,9 @@ nccl_lib_path = package_dir + "/.." + "/nvidia/nccl/lib" set_flags({"FLAGS_nccl_dir": nccl_lib_path}) + # flagcx_lib_path = os.getenv('FLAGCX_ROOT', '') + "/build/lib" + # set_flags({"FLAGS_flagcx_dir": flagcx_lib_path}) + cupti_dir_lib_path = package_dir + "/.."
+ "/nvidia/cuda_cupti/lib" set_flags({"FLAGS_cupti_dir": cupti_dir_lib_path}) diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 0d594e746697f..895b452f62739 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -75,7 +75,7 @@ def _get_global_env(): # Name of the default group for init_parallel_env _default_group_name = "_default_pg" -_valid_backend_list = ['nccl', 'gloo', 'heter', 'xccl', 'bkcl'] +_valid_backend_list = ['nccl', 'gloo', 'heter', 'xccl', 'bkcl', 'flagcx'] _default_store = None # the default tcp store _default_backend = None _default_timeout = datetime.timedelta(seconds=1800) @@ -178,6 +178,15 @@ def _new_process_group_impl( ) elif backend == "bkcl": pg = core.ProcessGroupBKCL.create(store, rank, world_size, group_id) + elif backend == "flagcx": + pg = core.ProcessGroupFlagcx.create( + store, + rank, + world_size, + group_id, + genv.pg_timeout, + nccl_comm_init_option, + ) return pg diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index 1e62a248aa769..52739e2df2805 100755 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -1937,6 +1937,7 @@ def check_backend(backend): 'auto', 'heter', 'xccl', + 'flagcx', ]: raise ValueError( "paddle.distributed initialize error, " @@ -1957,6 +1958,12 @@ def check_backend(backend): "your paddle is not compiled with xpu but you assign 'bkcl' as backend." ) + if backend == 'flagcx' and not framework.core.is_compiled_with_flagcx(): + raise ValueError( + "paddle.distributed initialize error, " + "your paddle is not compiled with flagcx but you assign 'flagcx' as backend." + ) + def block_windows_and_macos(backend): if backend != 'gloo': diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 61879ddb1663f..cb8b661fa122a 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -932,7 +932,7 @@ def _start_kv_server(port, http_server_d, size): def _is_cpuonly(backend): check_backend(backend) if ( - backend in ['auto', 'nccl', 'bkcl', 'heter'] + backend in ['auto', 'nccl', 'bkcl', 'heter', 'flagcx'] and (core.is_compiled_with_cuda() or core.is_compiled_with_xpu()) ) or backend == 'xccl': # passes 'auto' and can use cuda or xpu, use the default logics. 
so return False @@ -1134,7 +1134,7 @@ def init_parallel_env() -> Group: default_store = core.create_or_get_global_tcp_store() _set_default_store(default_store) - if backend in ["nccl", 'xccl', 'bkcl']: + if backend in ["nccl", 'xccl', 'bkcl', 'flagcx']: core.CommContextManager.set_device_id(parallel_env.device_id) pg = _new_process_group_impl( diff --git a/python/paddle/jit/sot/utils/paddle_api_config.py b/python/paddle/jit/sot/utils/paddle_api_config.py index e872dd4ee200b..6aad6f30a9330 100644 --- a/python/paddle/jit/sot/utils/paddle_api_config.py +++ b/python/paddle/jit/sot/utils/paddle_api_config.py @@ -163,5 +163,6 @@ def is_directly_run_api(api): paddle.base.libpaddle.is_compiled_with_distribute, paddle.base.libpaddle.is_compiled_with_brpc, paddle.base.libpaddle.is_compiled_with_dist, + paddle.base.libpaddle.is_compiled_with_flagcx, } return api in NATIVE_CODE_PURE_FUNCTIONS diff --git a/python/setup.py.in b/python/setup.py.in index caeef2be4b77c..10d54865f0818 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -960,6 +960,9 @@ package_dir={ libs_path='${PADDLE_BINARY_DIR}/python/paddle/libs' package_data['paddle.libs']= [] +if('${WITH_FLAGCX}' == 'ON'): + package_data['paddle.libs'] += [('libflagcx' if os.name != 'nt' else 'flagcx') + ext_name] + shutil.copy('${FLAGCX_LIB}', libs_path) if('${WITH_SHARED_PHI}' == 'ON'): package_data['paddle.libs'] += [('libphi' if os.name != 'nt' else 'phi') + ext_name] shutil.copy('${PHI_LIB}', libs_path) diff --git a/setup.py b/setup.py index d41cf89c749a1..7035848895a7c 100644 --- a/setup.py +++ b/setup.py @@ -1339,6 +1339,11 @@ def get_package_data_and_package_dir(): # put all thirdparty libraries in paddle.libs libs_path = paddle_binary_dir + '/python/paddle/libs' package_data['paddle.libs'] = [] + if env_dict.get("WITH_FLAGCX") == 'ON': + package_data['paddle.libs'] += [ + ('libflagcx' if os.name != 'nt' else 'flagcx') + ext_suffix + ] + shutil.copy(env_dict.get("FLAGCX_LIB"), libs_path) if env_dict.get("WITH_SHARED_PHI") == "ON": package_data['paddle.libs'] += [ ('libphi' if os.name != 'nt' else 'phi') + ext_suffix diff --git a/test/collective/CMakeLists.txt b/test/collective/CMakeLists.txt index d5a699b7df67d..286e672fc5414 100644 --- a/test/collective/CMakeLists.txt +++ b/test/collective/CMakeLists.txt @@ -87,7 +87,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;FLAGS_enable_pir_api=0" ) set_tests_properties(test_collective_barrier_api - PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "450" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) bash_test_modules( @@ -170,7 +170,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_isend_irecv_api MODULES test_collective_isend_irecv_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_isend_irecv_api - PROPERTIES TIMEOUT "160" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "320" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -268,7 +268,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_collective_gather_api MODULES test_collective_gather_api ENVS "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python") set_tests_properties(test_collective_gather_api - PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "360" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -279,7 +279,7 @@ if((WITH_GPU OR 
WITH_ROCM) AND (LINUX)) "http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;FLAGS_enable_pir_api=0" ) set_tests_properties(test_collective_sendrecv_api - PROPERTIES TIMEOUT "500" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "1000" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -295,7 +295,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_allgather_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_allgather_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -303,7 +303,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_allreduce_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_allreduce_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -311,7 +311,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_alltoall_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_alltoall_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -319,7 +319,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_alltoall_single_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_alltoall_single_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -327,7 +327,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_broadcast_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_broadcast_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -335,7 +335,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_reduce_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_reduce_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -343,7 +343,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_reduce_scatter_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_reduce_scatter_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( @@ -351,7 +351,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_scatter_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_scatter_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) 
py_test_modules( @@ -359,7 +359,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) test_communication_stream_sendrecv_api ENVS "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") set_tests_properties(test_communication_stream_sendrecv_api - PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") + PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( diff --git a/test/collective/communication_stream_allgather_api_dygraph.py b/test/collective/communication_stream_allgather_api_dygraph.py index cc30358dccd50..4b018fc210941 100644 --- a/test/collective/communication_stream_allgather_api_dygraph.py +++ b/test/collective/communication_stream_allgather_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." ) diff --git a/test/collective/communication_stream_allreduce_api_dygraph.py b/test/collective/communication_stream_allreduce_api_dygraph.py index 07651e22fac95..13c5e3ee020b6 100644 --- a/test/collective/communication_stream_allreduce_api_dygraph.py +++ b/test/collective/communication_stream_allreduce_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." ) diff --git a/test/collective/communication_stream_alltoall_api_dygraph.py b/test/collective/communication_stream_alltoall_api_dygraph.py index eebad92e46cd4..5907bbbf63718 100644 --- a/test/collective/communication_stream_alltoall_api_dygraph.py +++ b/test/collective/communication_stream_alltoall_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." ) diff --git a/test/collective/communication_stream_alltoall_single_api_dygraph.py b/test/collective/communication_stream_alltoall_single_api_dygraph.py index 2742fac67ad96..b6ba977e84f33 100644 --- a/test/collective/communication_stream_alltoall_single_api_dygraph.py +++ b/test/collective/communication_stream_alltoall_single_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." 
) diff --git a/test/collective/communication_stream_broadcast_api_dygraph.py b/test/collective/communication_stream_broadcast_api_dygraph.py index 59a734cef89d8..817fc726e58ba 100644 --- a/test/collective/communication_stream_broadcast_api_dygraph.py +++ b/test/collective/communication_stream_broadcast_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." ) diff --git a/test/collective/communication_stream_reduce_api_dygraph.py b/test/collective/communication_stream_reduce_api_dygraph.py index 4233fa80e109c..27740315cf427 100644 --- a/test/collective/communication_stream_reduce_api_dygraph.py +++ b/test/collective/communication_stream_reduce_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." ) diff --git a/test/collective/communication_stream_reduce_scatter_api_dygraph.py b/test/collective/communication_stream_reduce_scatter_api_dygraph.py index d8bb928c46010..4d94ed52ead12 100644 --- a/test/collective/communication_stream_reduce_scatter_api_dygraph.py +++ b/test/collective/communication_stream_reduce_scatter_api_dygraph.py @@ -32,7 +32,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." ) diff --git a/test/collective/communication_stream_scatter_api_dygraph.py b/test/collective/communication_stream_scatter_api_dygraph.py index 85c7acbb6c6ba..4474c6cc85152 100644 --- a/test/collective/communication_stream_scatter_api_dygraph.py +++ b/test/collective/communication_stream_scatter_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." ) diff --git a/test/collective/communication_stream_sendrecv_api_dygraph.py b/test/collective/communication_stream_sendrecv_api_dygraph.py index c8d2d7e780409..f3f17905c92ae 100644 --- a/test/collective/communication_stream_sendrecv_api_dygraph.py +++ b/test/collective/communication_stream_sendrecv_api_dygraph.py @@ -29,7 +29,7 @@ def __init__(self): self._shape = eval(os.getenv("shape")) self._dtype = os.getenv("dtype") self._seeds = eval(os.getenv("seeds")) - if self._backend not in ["nccl", "gloo"]: + if self._backend not in ["nccl", "gloo", "flagcx"]: raise NotImplementedError( "Only support nccl and gloo as the backend for now." 
) diff --git a/test/collective/test_collective_allgather_api.py b/test/collective/test_collective_allgather_api.py index 3a25aa341e27b..d49bf8328bac0 100644 --- a/test/collective/test_collective_allgather_api.py +++ b/test/collective/test_collective_allgather_api.py @@ -44,6 +44,26 @@ def test_allgather_nccl(self): dtype=dtype, ) + def test_allgather_flagcx(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_allgather_api.py", + "allgather", + "flagcx", + dtype=dtype, + ) + def test_allgather_gloo(self): dtypes_to_test = [ "float16", @@ -86,6 +106,27 @@ def test_allgather_nccl_dygraph(self): dtype=dtype, ) + def test_allgather_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_allgather_api_dygraph.py", + "allgather", + "flagcx", + static_mode="0", + dtype=dtype, + ) + def test_allgather_nccl_dygraph_with_trace_hang(self): dtypes_to_test = [ "float32", diff --git a/test/collective/test_collective_allreduce_api.py b/test/collective/test_collective_allreduce_api.py index 49cf8448ee642..ef57456ebdc8e 100644 --- a/test/collective/test_collective_allreduce_api.py +++ b/test/collective/test_collective_allreduce_api.py @@ -163,6 +163,29 @@ def test_allreduce_gloo_dygraph(self): dtype=dtype, ) + def test_allreduce_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if self._nccl_version >= 21000: + dtypes_to_test.append("bfloat16") + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_allreduce_api_dygraph.py", + "allreduce", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_alltoall_api.py b/test/collective/test_collective_alltoall_api.py index ac6c2f87493d3..e7248d5396984 100644 --- a/test/collective/test_collective_alltoall_api.py +++ b/test/collective/test_collective_alltoall_api.py @@ -105,6 +105,27 @@ def test_alltoall_unequal_split_nccl_dygraph(self): dtype=dtype, ) + def test_alltoall_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_alltoall_api_dygraph.py", + "alltoall", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_barrier_api.py b/test/collective/test_collective_barrier_api.py index 6bdcd4fadfc01..0c21f89017039 100644 --- a/test/collective/test_collective_barrier_api.py +++ b/test/collective/test_collective_barrier_api.py @@ -33,6 +33,12 @@ def test_barrier_gloo(self): "collective_barrier_api.py", "barrier", "gloo", "5" ) + def test_barrier_flagcx(self): + if paddle.base.core.is_compiled_with_flagcx(): + self.check_with_place( + "collective_barrier_api.py", "barrier", "flagcx", static_mode="0" + ) + if __name__ == '__main__': unittest.main() diff --git a/test/collective/test_collective_broadcast_api.py b/test/collective/test_collective_broadcast_api.py index 
a6d596fcc5e50..fa797e666525d 100644 --- a/test/collective/test_collective_broadcast_api.py +++ b/test/collective/test_collective_broadcast_api.py @@ -125,6 +125,34 @@ def test_broadcast_gloo_dygraph(self): dtype=dtype, ) + def test_broadcast_flagcx(self): + self.check_with_place( + "collective_broadcast_api.py", + "broadcast", + "flagcx", + ) + + def test_broadcast_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_broadcast_api_dygraph.py", + "broadcast", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_gather_api.py b/test/collective/test_collective_gather_api.py index 8f2c24a3dac34..6d2cc3f508cdd 100644 --- a/test/collective/test_collective_gather_api.py +++ b/test/collective/test_collective_gather_api.py @@ -47,6 +47,27 @@ def test_gather_nccl_dygraph(self): dtype=dtype, ) + def test_gather_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_gather_api_dygraph.py", + "gather", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_isend_irecv_api.py b/test/collective/test_collective_isend_irecv_api.py index ba1a8897a78de..54073700257ce 100644 --- a/test/collective/test_collective_isend_irecv_api.py +++ b/test/collective/test_collective_isend_irecv_api.py @@ -16,6 +16,8 @@ import legacy_test.test_collective_api_base as test_base +import paddle + class TestCollectiveIsendIrecvAPI(test_base.TestDistBase): def _setup_config(self): @@ -43,6 +45,27 @@ def test_isend_irecv_nccl_dygraph(self): dtype=dtype, ) + def test_isend_irecv_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_isend_irecv_api_dygraph.py", + "sendrecv", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_reduce_api.py b/test/collective/test_collective_reduce_api.py index a437b92c541f1..1f185fde6913b 100644 --- a/test/collective/test_collective_reduce_api.py +++ b/test/collective/test_collective_reduce_api.py @@ -164,6 +164,30 @@ def test_reduce_gloo_dygraph(self): dtype=dtype, ) + def test_reduce_flagcx(self): + self.check_with_place("collective_reduce_api.py", "reduce", "flagcx") + + def test_reduce_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_reduce_api_dygraph.py", + "reduce", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_reduce_scatter_api.py b/test/collective/test_collective_reduce_scatter_api.py index 6afe0d4601ec6..8eb37c36e8836 100644 --- a/test/collective/test_collective_reduce_scatter_api.py +++ 
b/test/collective/test_collective_reduce_scatter_api.py @@ -16,6 +16,8 @@ import legacy_test.test_collective_api_base as test_base +import paddle + class TestCollectiveReduceScatterAPI(test_base.TestDistBase): def _setup_config(self): @@ -61,6 +63,27 @@ def test_reduce_scatter_nccl_dygraph(self): dtype=dtype, ) + def test_reduce_scatter_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_reduce_scatter_api_dygraph.py", + "reduce_scatter", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_scatter_api.py b/test/collective/test_collective_scatter_api.py index 7d50c909c5ada..f5bc62ea84c1b 100644 --- a/test/collective/test_collective_scatter_api.py +++ b/test/collective/test_collective_scatter_api.py @@ -90,6 +90,44 @@ def test_scatter_gloo_dygraph(self): dtype=dtype, ) + def test_scatter_flagcx(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_scatter_api.py", + "scatter", + "flagcx", + dtype=dtype, + ) + + def test_scatter_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_scatter_api_dygraph.py", + "scatter", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_collective_sendrecv_api.py b/test/collective/test_collective_sendrecv_api.py index 7fe9e571c5dee..870e689a6ed25 100644 --- a/test/collective/test_collective_sendrecv_api.py +++ b/test/collective/test_collective_sendrecv_api.py @@ -75,6 +75,27 @@ def test_sendrecv_nccl_dygraph(self): dtype=dtype, ) + def test_sendrecv_flagcx_dygraph(self): + dtypes_to_test = [ + "float16", + "float32", + "float64", + "int32", + "int64", + "int8", + "uint8", + "bool", + ] + if paddle.base.core.is_compiled_with_flagcx(): + for dtype in dtypes_to_test: + self.check_with_place( + "collective_sendrecv_api_dygraph.py", + "sendrecv", + "flagcx", + static_mode="0", + dtype=dtype, + ) + if __name__ == "__main__": unittest.main() diff --git a/test/collective/test_communication_stream_allgather_api.py b/test/collective/test_communication_stream_allgather_api.py index e55bdb963cf6b..fd61e0ab8ad69 100644 --- a/test/collective/test_communication_stream_allgather_api.py +++ b/test/collective/test_communication_stream_allgather_api.py @@ -16,17 +16,22 @@ import test_communication_api_base as test_base +import paddle + class TestCommunicationStreamAllgatherAPI(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git a/test/collective/test_communication_stream_allreduce_api.py b/test/collective/test_communication_stream_allreduce_api.py index 
60386a6262ff2..807200a3777fe 100644 --- a/test/collective/test_communication_stream_allreduce_api.py +++ b/test/collective/test_communication_stream_allreduce_api.py @@ -16,17 +16,22 @@ import test_communication_api_base as test_base +import paddle + class TestCommunicationStreamAllreduceAPI(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git a/test/collective/test_communication_stream_alltoall_api.py b/test/collective/test_communication_stream_alltoall_api.py index af75bec3884c7..b190e05d3e4dc 100644 --- a/test/collective/test_communication_stream_alltoall_api.py +++ b/test/collective/test_communication_stream_alltoall_api.py @@ -16,17 +16,22 @@ import test_communication_api_base as test_base +import paddle + class TestCommunicationStreamAllToAllAPI(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git a/test/collective/test_communication_stream_alltoall_single_api.py b/test/collective/test_communication_stream_alltoall_single_api.py index 7509d6609d9c6..5f20575d4e419 100644 --- a/test/collective/test_communication_stream_alltoall_single_api.py +++ b/test/collective/test_communication_stream_alltoall_single_api.py @@ -16,6 +16,8 @@ import test_communication_api_base as test_base +import paddle + class TestCommunicationStreamAllToAllSingleAPI( test_base.CommunicationTestDistBase @@ -23,12 +25,15 @@ class TestCommunicationStreamAllToAllSingleAPI( def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git a/test/collective/test_communication_stream_broadcast_api.py b/test/collective/test_communication_stream_broadcast_api.py index e6228297dc610..6acc5b41d6cb6 100644 --- a/test/collective/test_communication_stream_broadcast_api.py +++ b/test/collective/test_communication_stream_broadcast_api.py @@ -18,18 +18,22 @@ sys.path.append("../legacy_test") import test_communication_api_base as test_base +import paddle class TestCommunicationStreamBroadcastAPI(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git 
a/test/collective/test_communication_stream_reduce_api.py b/test/collective/test_communication_stream_reduce_api.py index df39b1ff294c1..defd533a138e3 100644 --- a/test/collective/test_communication_stream_reduce_api.py +++ b/test/collective/test_communication_stream_reduce_api.py @@ -15,18 +15,22 @@ import unittest import test_communication_api_base as test_base +import paddle class TestCommunicationStreamReduceAPI(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git a/test/collective/test_communication_stream_reduce_scatter_api.py b/test/collective/test_communication_stream_reduce_scatter_api.py index 3b0c8c7664984..73fab96dcd89d 100644 --- a/test/collective/test_communication_stream_reduce_scatter_api.py +++ b/test/collective/test_communication_stream_reduce_scatter_api.py @@ -15,6 +15,7 @@ import unittest import test_communication_api_base as test_base +import paddle class TestCommunicationStreamReduceScatterAPI( @@ -23,12 +24,15 @@ class TestCommunicationStreamReduceScatterAPI( def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git a/test/collective/test_communication_stream_scatter_api.py b/test/collective/test_communication_stream_scatter_api.py index be08cbfe58202..f03570e1c3638 100644 --- a/test/collective/test_communication_stream_scatter_api.py +++ b/test/collective/test_communication_stream_scatter_api.py @@ -15,18 +15,22 @@ import unittest import test_communication_api_base as test_base +import paddle class TestCommunicationStreamScatterAPI(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], "use_calc_stream": ["True", "False"], } diff --git a/test/collective/test_communication_stream_sendrecv_api.py b/test/collective/test_communication_stream_sendrecv_api.py index 048ae08df2fe1..229507612486b 100644 --- a/test/collective/test_communication_stream_sendrecv_api.py +++ b/test/collective/test_communication_stream_sendrecv_api.py @@ -15,18 +15,22 @@ import unittest import test_communication_api_base as test_base +import paddle class TestCommunicationStreamSendRecvAPI(test_base.CommunicationTestDistBase): def setUp(self): super().setUp(num_of_devices=2, timeout=120) self._default_envs = { - "backend": "nccl", "shape": "(100, 200)", "dtype": "float32", "seeds": str(self._seeds), } + backend_list = ["nccl"] + if paddle.base.core.is_compiled_with_flagcx(): + backend_list.append("flagcx") self._changeable_envs = { + "backend": backend_list, "sync_op": ["True", "False"], 
"use_calc_stream": ["True", "False"], } diff --git a/test/legacy_test/test_collective_api_base.py b/test/legacy_test/test_collective_api_base.py index 81087219e589b..6b2dd59665753 100644 --- a/test/legacy_test/test_collective_api_base.py +++ b/test/legacy_test/test_collective_api_base.py @@ -335,7 +335,7 @@ def check_with_place( dtype=None, reduce_type=None, ): - if backend == "nccl" or backend == "bkcl": + if backend == "nccl" or backend == "bkcl" or backend == "flagcx": with_gloo = '0' else: with_gloo = '1' diff --git a/third_party/flagcx b/third_party/flagcx new file mode 160000 index 0000000000000..7e6c4cc3cad3f --- /dev/null +++ b/third_party/flagcx @@ -0,0 +1 @@ +Subproject commit 7e6c4cc3cad3fce9b3dedfe46a9d195d616e8ffa