34
34
#include " paddle/phi/core/platform/cuda_graph_with_memory_pool.h"
35
35
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
36
36
#include " paddle/common/flags.h"
37
+ #include " paddle/fluid/distributed/collective/process_group.h"
38
+ #include " paddle/fluid/distributed/collective/process_group_nccl.h"
37
39
#include " paddle/fluid/platform/device/gpu/nccl_helper.h"
38
40
#include " paddle/phi/core/distributed/comm_context_manager.h"
39
41
#include " paddle/phi/core/distributed/nccl_comm_context.h"
@@ -139,6 +141,7 @@ ProgramInterpreter::~ProgramInterpreter() {
139
141
}
140
142
141
143
void ProgramInterpreter::RunImpl () {
144
+ VLOG (2 ) << " [liyamei ProgramInterpreter] start RunImpl" ;
142
145
// lazy initialization of gc, do not create gc is the program only run once
143
146
if (!gc_) {
144
147
gc_ = CreateInterpreterCoreGarbageCollector (place_, vec_instruction_);
@@ -150,12 +153,14 @@ void ProgramInterpreter::RunImpl() {
150
153
((execution_config_.used_for_jit || execution_config_.used_for_cinn ) &&
151
154
(sync_op_num_ == 0 ))) {
152
155
VLOG (4 ) << " Tracing Instruction List" ;
156
+ VLOG (2 ) << " [liyamei ProgramInterpreter] Tracing Instruction List" ;
153
157
TraceInstructionList (vec_instruction_);
154
158
} else {
155
159
VLOG (4 ) << " Non-tracing" ;
156
160
// For the program that only run once, it is no need to
157
161
// create work_queue, so the async_work_queue_ is created
158
162
// until the second step run.
163
+ VLOG (2 ) << " [liyamei ProgramInterpreter] Non-tracing" ;
159
164
async_work_queue_ = GetWorkQueue ();
160
165
ExecuteInstructionList (vec_instruction_);
161
166
}
@@ -927,7 +932,8 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
927
932
auto place = instr_node.DeviceContext ().GetPlace ();
928
933
Scope* local_scope = HasLocalScope () ? var_scope_.GetMutableLocalScope ()
929
934
: var_scope_.GetMutableScope ();
930
- VLOG (4 ) << " Start run " << place << " " << op->DebugStringEx (local_scope);
935
+ VLOG (2 ) << " [liyamei RunOperator] Start run " << place << " "
936
+ << op->DebugStringEx (local_scope);
931
937
932
938
if (execution_config_.used_for_inference ) {
933
939
for (auto & hook : input_hookfuncs_) {
@@ -1010,15 +1016,39 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
1010
1016
VLOG (4 ) << " Run function kernel: " << op->Type ();
1011
1017
VLOG (4 ) << instr_node.InnerRuntimeContext ().get () << " "
1012
1018
<< &instr_node.DeviceContext ();
1019
+
1020
+ auto dev_ctx =
1021
+ const_cast <phi::DeviceContext*>(&instr_node.DeviceContext ());
1022
+ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
1023
+ auto attrs = op->Attrs ();
1024
+ if (attrs.find (" ring_id" ) != attrs.end ()) {
1025
+ auto ring_id_attr = attrs.at (" ring_id" );
1026
+ int ring_id = PADDLE_GET (int , ring_id_attr);
1027
+ auto map = distributed::ProcessGroupMapFromGid::getInstance ();
1028
+ if (map->has (ring_id)) {
1029
+ distributed::ProcessGroup* pg = map->get (ring_id);
1030
+ auto comm_context =
1031
+ static_cast <paddle::distributed::ProcessGroupNCCL*>(pg)
1032
+ ->GetOrCreateCommContext (place);
1033
+ dev_ctx =
1034
+ static_cast <phi::distributed::NCCLCommContext*>(comm_context)
1035
+ ->GetDevContext ();
1036
+ dev_ctx->SetCommContext (comm_context);
1037
+ } else {
1038
+ VLOG (3 ) << " ring_id " << ring_id
1039
+ << " not found in ProcessGroupMapFromGid " ;
1040
+ }
1041
+ }
1042
+ #endif
1013
1043
phi::KernelContext phi_kernel_context;
1014
1044
op_with_kernel->BuildPhiKernelContext (
1015
1045
*instr_node.InnerRuntimeContext ().get (),
1016
- const_cast <phi::DeviceContext*>(&instr_node. DeviceContext ()) ,
1046
+ dev_ctx ,
1017
1047
&phi_kernel_context);
1018
1048
1019
1049
(*kernel)(&phi_kernel_context);
1020
1050
} else {
1021
- VLOG (4 ) << " Run structure kernel: " << op->Type ();
1051
+ VLOG (2 ) << " Run structure kernel: " << op->Type ();
1022
1052
(*kernel)(instr_node.InnerExecutionContext ().get ());
1023
1053
}
1024
1054
} else { // fluid kernel
@@ -1148,7 +1178,7 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
1148
1178
}
1149
1179
1150
1180
void ProgramInterpreter::RunInstruction (const Instruction& instr_node) {
1151
- VLOG (5 ) << __func__ << " OP id:" << instr_node.Id ()
1181
+ VLOG (2 ) << __func__ << " OP id:" << instr_node.Id ()
1152
1182
<< " name:" << instr_node.OpBase ()->Type () << " type:"
1153
1183
<< (instr_node.KernelType () == OpFuncType::kCpuSync
1154
1184
? " kCpuSync"
@@ -1603,6 +1633,7 @@ bool ProgramInterpreter::HasLocalScope() const {
1603
1633
// KQueueSync Ops is 0, we choose Trace mode.
1604
1634
void ProgramInterpreter::TraceInstructionList (
1605
1635
const std::vector<Instruction>& vec_instr) {
1636
+ VLOG (2 ) << " [liyamei ProgramInterpreter] start TraceInstructionList" ;
1606
1637
unfinished_op_number_ = vec_instr.size ();
1607
1638
if (unfinished_op_number_ == 0 ) {
1608
1639
VLOG (4 ) << " No op to run, return" ;
0 commit comments