Skip to content

Commit 871dc43

Browse files
committed
[Comm] fix to support comm init for inference in phi comm ops
1 parent 9240c25 commit 871dc43

File tree

3 files changed

+95
-4
lines changed

3 files changed

+95
-4
lines changed

paddle/fluid/framework/CMakeLists.txt

+22
Original file line number | Diff line number | Diff line change
@@ -289,6 +289,28 @@ if(WITH_XPU)
289289
common
290290
op_compat_infos
291291
type_info)
292+
elseif(WITH_NCCL OR WITH_RCCL)
293+
cc_library(
294+
operator
295+
SRCS operator.cc transfer_scope_cache.cc unused_var_check.cc
296+
infershape_utils.cc
297+
DEPS op_info
298+
proto_desc
299+
tensor
300+
scope
301+
glog
302+
shape_inference
303+
data_transform
304+
lod_tensor
305+
op_kernel_type
306+
op_call_stack
307+
detail_op_handle
308+
phi_utils
309+
phi
310+
common
311+
op_compat_infos
312+
type_info
313+
process_group_nccl)
292314
else()
293315
cc_library(
294316
operator

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

+38
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,10 @@
3939
#include "paddle/phi/core/kernel_context.h"
4040
#include "paddle/phi/core/kernel_factory.h"
4141
#include "paddle/phi/core/memory/stats.h"
42+
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
43+
#include "paddle/fluid/distributed/collective/process_group.h"
44+
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
45+
#endif
4246

4347
#ifdef PADDLE_WITH_DNNL
4448
#include "paddle/fluid/platform/onednn_helper.h"
@@ -865,6 +869,40 @@ void BuildOpFuncList(const phi::Place& place,
865869
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
866870
phi::KernelRegisteredType::FUNCTION) {
867871
VLOG(6) << op_type << " run function kernel";
872+
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
873+
auto attrs = op->Attrs();
874+
if (attrs.find("ring_id") != attrs.end()) {
875+
auto ring_id_attr = attrs.at("ring_id");
876+
int ring_id = PADDLE_GET(int, ring_id_attr);
877+
auto map = distributed::ProcessGroupMapFromGid::getInstance();
878+
if (map->has(ring_id)) {
879+
auto original_stream =
880+
static_cast<phi::GPUContext*>(dev_ctx)->cuda_stream();
881+
distributed::ProcessGroup* pg = map->get(ring_id);
882+
auto comm_context =
883+
static_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
884+
->GetOrCreateCommContext(place);
885+
dev_ctx =
886+
static_cast<phi::distributed::NCCLCommContext*>(comm_context)
887+
->GetDevContext();
888+
dev_ctx->SetCommContext(comm_context);
889+
890+
static_cast<phi::GPUContext*>(dev_ctx)->SetCUDAStream(
891+
original_stream, false);
892+
auto& instance =
893+
paddle::memory::allocation::AllocatorFacade::Instance();
894+
dev_ctx->SetAllocator(
895+
instance
896+
.GetAllocator(
897+
place,
898+
static_cast<phi::GPUContext*>(dev_ctx)->stream())
899+
.get());
900+
} else {
901+
VLOG(3) << "ring_id " << ring_id
902+
<< " not found in ProcessGroupMapFromGid ";
903+
}
904+
}
905+
#endif
868906
if (static_build) {
869907
FakeInitializeOutputsForFunctionKernel(
870908
*op,

paddle/fluid/framework/new_executor/program_interpreter.cc

+35-4
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,8 @@
3434
#include "paddle/phi/core/platform/cuda_graph_with_memory_pool.h"
3535
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
3636
#include "paddle/common/flags.h"
37+
#include "paddle/fluid/distributed/collective/process_group.h"
38+
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
3739
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
3840
#include "paddle/phi/core/distributed/comm_context_manager.h"
3941
#include "paddle/phi/core/distributed/nccl_comm_context.h"
@@ -139,6 +141,7 @@ ProgramInterpreter::~ProgramInterpreter() {
139141
}
140142

141143
void ProgramInterpreter::RunImpl() {
144+
VLOG(2) << "[liyamei ProgramInterpreter] start RunImpl";
142145
// lazy initialization of gc, do not create gc is the program only run once
143146
if (!gc_) {
144147
gc_ = CreateInterpreterCoreGarbageCollector(place_, vec_instruction_);
@@ -150,12 +153,14 @@ void ProgramInterpreter::RunImpl() {
150153
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
151154
(sync_op_num_ == 0))) {
152155
VLOG(4) << "Tracing Instruction List";
156+
VLOG(2) << "[liyamei ProgramInterpreter] Tracing Instruction List";
153157
TraceInstructionList(vec_instruction_);
154158
} else {
155159
VLOG(4) << "Non-tracing";
156160
// For the program that only run once, it is no need to
157161
// create work_queue, so the async_work_queue_ is created
158162
// until the second step run.
163+
VLOG(2) << "[liyamei ProgramInterpreter] Non-tracing";
159164
async_work_queue_ = GetWorkQueue();
160165
ExecuteInstructionList(vec_instruction_);
161166
}
@@ -927,7 +932,8 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
927932
auto place = instr_node.DeviceContext().GetPlace();
928933
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
929934
: var_scope_.GetMutableScope();
930-
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
935+
VLOG(2) << "[liyamei RunOperator] Start run " << place << " "
936+
<< op->DebugStringEx(local_scope);
931937

932938
if (execution_config_.used_for_inference) {
933939
for (auto& hook : input_hookfuncs_) {
@@ -1010,15 +1016,39 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
10101016
VLOG(4) << "Run function kernel: " << op->Type();
10111017
VLOG(4) << instr_node.InnerRuntimeContext().get() << " "
10121018
<< &instr_node.DeviceContext();
1019+
1020+
auto dev_ctx =
1021+
const_cast<phi::DeviceContext*>(&instr_node.DeviceContext());
1022+
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
1023+
auto attrs = op->Attrs();
1024+
if (attrs.find("ring_id") != attrs.end()) {
1025+
auto ring_id_attr = attrs.at("ring_id");
1026+
int ring_id = PADDLE_GET(int, ring_id_attr);
1027+
auto map = distributed::ProcessGroupMapFromGid::getInstance();
1028+
if (map->has(ring_id)) {
1029+
distributed::ProcessGroup* pg = map->get(ring_id);
1030+
auto comm_context =
1031+
static_cast<paddle::distributed::ProcessGroupNCCL*>(pg)
1032+
->GetOrCreateCommContext(place);
1033+
dev_ctx =
1034+
static_cast<phi::distributed::NCCLCommContext*>(comm_context)
1035+
->GetDevContext();
1036+
dev_ctx->SetCommContext(comm_context);
1037+
} else {
1038+
VLOG(3) << "ring_id " << ring_id
1039+
<< " not found in ProcessGroupMapFromGid ";
1040+
}
1041+
}
1042+
#endif
10131043
phi::KernelContext phi_kernel_context;
10141044
op_with_kernel->BuildPhiKernelContext(
10151045
*instr_node.InnerRuntimeContext().get(),
1016-
const_cast<phi::DeviceContext*>(&instr_node.DeviceContext()),
1046+
dev_ctx,
10171047
&phi_kernel_context);
10181048

10191049
(*kernel)(&phi_kernel_context);
10201050
} else {
1021-
VLOG(4) << "Run structure kernel: " << op->Type();
1051+
VLOG(2) << "Run structure kernel: " << op->Type();
10221052
(*kernel)(instr_node.InnerExecutionContext().get());
10231053
}
10241054
} else { // fluid kernel
@@ -1148,7 +1178,7 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
11481178
}
11491179

11501180
void ProgramInterpreter::RunInstruction(const Instruction& instr_node) {
1151-
VLOG(5) << __func__ << " OP id:" << instr_node.Id()
1181+
VLOG(2) << __func__ << " OP id:" << instr_node.Id()
11521182
<< " name:" << instr_node.OpBase()->Type() << " type:"
11531183
<< (instr_node.KernelType() == OpFuncType::kCpuSync
11541184
? "kCpuSync"
@@ -1603,6 +1633,7 @@ bool ProgramInterpreter::HasLocalScope() const {
16031633
// KQueueSync Ops is 0, we choose Trace mode.
16041634
void ProgramInterpreter::TraceInstructionList(
16051635
const std::vector<Instruction>& vec_instr) {
1636+
VLOG(2) << "[liyamei ProgramInterpreter] start TraceInstructionList";
16061637
unfinished_op_number_ = vec_instr.size();
16071638
if (unfinished_op_number_ == 0) {
16081639
VLOG(4) << "No op to run, return";

0 commit comments

Comments (0)