diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index f8a2704f0033d6..7bef3485feb84c 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -289,6 +289,28 @@ if(WITH_XPU) common op_compat_infos type_info) +elseif(WITH_NCCL OR WITH_RCCL) + cc_library( + operator + SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc + infershape_utils.cc + DEPS op_info + proto_desc + tensor + scope + glog + shape_inference + data_transform + lod_tensor + op_kernel_type + op_call_stack + detail_op_handle + phi_utils + phi + common + op_compat_infos + type_info + process_group_nccl) else() cc_library( operator diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 0fdc4fb93cd70a..e078b9c2930808 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -39,6 +39,10 @@ #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/memory/stats.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_nccl.h" +#endif #ifdef PADDLE_WITH_DNNL #include "paddle/fluid/platform/onednn_helper.h" @@ -865,6 +869,40 @@ void BuildOpFuncList(const phi::Place& place, op_func_node.phi_kernel_->GetKernelRegisteredType() == phi::KernelRegisteredType::FUNCTION) { VLOG(6) << op_type << " run function kernel"; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto attrs = op->Attrs(); + if (attrs.find("ring_id") != attrs.end()) { + auto ring_id_attr = attrs.at("ring_id"); + int ring_id = PADDLE_GET(int, ring_id_attr); + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + if (map->has(ring_id)) { + auto original_stream = + static_cast(dev_ctx)->cuda_stream(); + distributed::ProcessGroup* pg = map->get(ring_id); + auto comm_context = + static_cast(pg) + ->GetOrCreateCommContext(place); + dev_ctx = + static_cast(comm_context) + ->GetDevContext(); + dev_ctx->SetCommContext(comm_context); + + static_cast(dev_ctx)->SetCUDAStream( + original_stream, false); + auto& instance = + paddle::memory::allocation::AllocatorFacade::Instance(); + dev_ctx->SetAllocator( + instance + .GetAllocator( + place, + static_cast(dev_ctx)->stream()) + .get()); + } else { + VLOG(3) << "ring_id " << ring_id + << " not found in ProcessGroupMapFromGid "; + } + } +#endif if (static_build) { FakeInitializeOutputsForFunctionKernel( *op, diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 36481a9e77ad21..c854558f65d09c 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -34,6 +34,8 @@ #include "paddle/phi/core/platform/cuda_graph_with_memory_pool.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/common/flags.h" +#include "paddle/fluid/distributed/collective/process_group.h" +#include "paddle/fluid/distributed/collective/process_group_nccl.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" @@ -1010,10 +1012,34 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) { VLOG(4) << "Run function kernel: " << op->Type(); VLOG(4) << instr_node.InnerRuntimeContext().get() << " " << &instr_node.DeviceContext(); + + auto dev_ctx = + const_cast(&instr_node.DeviceContext()); +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto attrs = op->Attrs(); + if (attrs.find("ring_id") != attrs.end()) { + auto ring_id_attr = attrs.at("ring_id"); + int ring_id = PADDLE_GET(int, ring_id_attr); + auto map = distributed::ProcessGroupMapFromGid::getInstance(); + if (map->has(ring_id)) { + distributed::ProcessGroup* pg = map->get(ring_id); + auto comm_context = + static_cast(pg) + ->GetOrCreateCommContext(place); + dev_ctx = + static_cast(comm_context) + ->GetDevContext(); + dev_ctx->SetCommContext(comm_context); + } else { + VLOG(3) << "ring_id " << ring_id + << " not found in ProcessGroupMapFromGid "; + } + } +#endif phi::KernelContext phi_kernel_context; op_with_kernel->BuildPhiKernelContext( *instr_node.InnerRuntimeContext().get(), - const_cast(&instr_node.DeviceContext()), + dev_ctx, &phi_kernel_context); (*kernel)(&phi_kernel_context);