Skip to content

Commit 8089550

Browse files
authored
[Comm] fix to support comm init for inference in phi comm ops (#69653)
1 parent 6fd300b commit 8089550

File tree

3 files changed

+87
-1
lines changed

3 files changed

+87
-1
lines changed

paddle/fluid/framework/CMakeLists.txt

+22
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,28 @@ if(WITH_XPU)
289289
common
290290
op_compat_infos
291291
type_info)
292+
elseif(WITH_NCCL OR WITH_RCCL)
293+
cc_library(
294+
operator
295+
SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
296+
infershape_utils.cc
297+
DEPS op_info
298+
proto_desc
299+
tensor
300+
scope
301+
glog
302+
shape_inference
303+
data_transform
304+
lod_tensor
305+
op_kernel_type
306+
op_call_stack
307+
detail_op_handle
308+
phi_utils
309+
phi
310+
common
311+
op_compat_infos
312+
type_info
313+
process_group_nccl)
292314
else()
293315
cc_library(
294316
operator

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

+38
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
#include "paddle/phi/core/kernel_context.h"
4040
#include "paddle/phi/core/kernel_factory.h"
4141
#include "paddle/phi/core/memory/stats.h"
42+
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
43+
#include "paddle/fluid/distributed/collective/process_group.h"
44+
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
45+
#endif
4246

4347
#ifdef PADDLE_WITH_DNNL
4448
#include "paddle/fluid/platform/onednn_helper.h"
@@ -865,6 +869,40 @@ void BuildOpFuncList(const phi::Place& place,
865869
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
866870
phi::KernelRegisteredType::FUNCTION) {
867871
VLOG(6) << op_type << " run function kernel";
872+
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
873+
auto attrs = op->Attrs();
874+
if (attrs.find("ring_id") != attrs.end()) {
875+
auto ring_id_attr = attrs.at("ring_id");
876+
int ring_id = PADDLE_GET(int, ring_id_attr);
877+
auto map = distributed::ProcessGroupMapFromGid::getInstance();
878+
if (map->has(ring_id)) {
879+
auto original_stream =
880+
static_cast<phi::GPUContext*>(dev_ctx)->cuda_stream();
881+
distributed::ProcessGroup* pg = map->get(ring_id);
882+
auto comm_context =
883+
static_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
884+
->GetOrCreateCommContext(place);
885+
dev_ctx =
886+
static_cast<phi::distributed::NCCLCommContext*>(comm_context)
887+
->GetDevContext();
888+
dev_ctx->SetCommContext(comm_context);
889+
890+
static_cast<phi::GPUContext*>(dev_ctx)->SetCUDAStream(
891+
original_stream, false);
892+
auto& instance =
893+
paddle::memory::allocation::AllocatorFacade::Instance();
894+
dev_ctx->SetAllocator(
895+
instance
896+
.GetAllocator(
897+
place,
898+
static_cast<phi::GPUContext*>(dev_ctx)->stream())
899+
.get());
900+
} else {
901+
VLOG(3) << "ring_id " << ring_id
902+
<< " not found in ProcessGroupMapFromGid ";
903+
}
904+
}
905+
#endif
868906
if (static_build) {
869907
FakeInitializeOutputsForFunctionKernel(
870908
*op,

paddle/fluid/framework/new_executor/program_interpreter.cc

+27-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
#include "paddle/phi/core/platform/cuda_graph_with_memory_pool.h"
3535
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
3636
#include "paddle/common/flags.h"
37+
#include "paddle/fluid/distributed/collective/process_group.h"
38+
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
3739
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
3840
#include "paddle/phi/core/distributed/comm_context_manager.h"
3941
#include "paddle/phi/core/distributed/nccl_comm_context.h"
@@ -1010,10 +1012,34 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
10101012
VLOG(4) << "Run function kernel: " << op->Type();
10111013
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
10121014
<< &instr_node.DeviceContext();
1015+
1016+
auto dev_ctx =
1017+
const_cast<phi::DeviceContext*>(&instr_node.DeviceContext());
1018+
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
1019+
auto attrs = op->Attrs();
1020+
if (attrs.find("ring_id") != attrs.end()) {
1021+
auto ring_id_attr = attrs.at("ring_id");
1022+
int ring_id = PADDLE_GET(int, ring_id_attr);
1023+
auto map = distributed::ProcessGroupMapFromGid::getInstance();
1024+
if (map->has(ring_id)) {
1025+
distributed::ProcessGroup* pg = map->get(ring_id);
1026+
auto comm_context =
1027+
static_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
1028+
->GetOrCreateCommContext(place);
1029+
dev_ctx =
1030+
static_cast<phi::distributed::NCCLCommContext*>(comm_context)
1031+
->GetDevContext();
1032+
dev_ctx->SetCommContext(comm_context);
1033+
} else {
1034+
VLOG(3) << "ring_id " << ring_id
1035+
<< " not found in ProcessGroupMapFromGid ";
1036+
}
1037+
}
1038+
#endif
10131039
phi::KernelContext phi_kernel_context;
10141040
op_with_kernel->BuildPhiKernelContext(
10151041
*instr_node.InnerRuntimeContext().get(),
1016-
const_cast<phi::DeviceContext*>(&instr_node.DeviceContext()),
1042+
dev_ctx,
10171043
&phi_kernel_context);
10181044

10191045
(*kernel)(&phi_kernel_context);

0 commit comments

Comments
 (0)