Skip to content

Commit a5b9a74

Browse files
authored
[Inference] Predictor support pir and new executor (#58452)
* support translate to pir
* fix
* update
* add ut
* disable inference_op_replace_pass
* support inplace_pass
* fix
* fix
* update
* add replace_fetch_with_shadow_output_pass
* fix comment
* update
* update
1 parent b4e7c1a commit a5b9a74

19 files changed

+301
-160
lines changed

paddle/fluid/framework/naive_executor.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ void NaiveExecutor::PrepareInterpreterCore(
6060
place_, program_desc.Block(0), scope, execution_config);
6161
}
6262

63+
void NaiveExecutor::PrepareInterpreterCore(
64+
Scope *scope,
65+
const ::pir::Program &pir_program,
66+
const framework::interpreter::ExecutionConfig &execution_config) {
67+
interpreter_core_ =
68+
std::make_unique<framework::InterpreterCore>(place_,
69+
std::vector<std::string>{},
70+
pir_program.block(),
71+
scope,
72+
execution_config);
73+
}
74+
6375
void NaiveExecutor::RunInterpreterCore(
6476
const std::vector<std::string> &feed_names, bool need_fetch) {
6577
platform::ScopedFlushDenormal flush;

paddle/fluid/framework/naive_executor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
3030
#include "paddle/fluid/framework/new_executor/interpretercore.h"
3131

32+
#include "paddle/pir/core/program.h"
33+
3234
namespace paddle {
3335
namespace framework {
3436

@@ -61,6 +63,12 @@ class NaiveExecutor {
6163
const framework::interpreter::ExecutionConfig& execution_config =
6264
framework::interpreter::ExecutionConfig{});
6365

66+
void PrepareInterpreterCore(
67+
Scope* scope,
68+
const ::pir::Program& pir_program,
69+
const framework::interpreter::ExecutionConfig& execution_config =
70+
framework::interpreter::ExecutionConfig{});
71+
6472
// Create variables before head.
6573
// Create parameters if persistable is true, or create the temporary variables
6674
// instead.

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,7 @@ const paddle::framework::Variable* GetVariableByName(
12071207
return nullptr;
12081208
}
12091209

1210-
std::vector<std::string> GetOriginInputNames(std::string op_name) {
1210+
std::vector<std::string> GetOriginInputNames(const std::string& op_name) {
12111211
std::vector<std::string> ret;
12121212
pir::IrContext* ctx = pir::IrContext::Instance();
12131213
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
@@ -1220,7 +1220,7 @@ std::vector<std::string> GetOriginInputNames(std::string op_name) {
12201220
return ret;
12211221
}
12221222

1223-
std::vector<std::string> GetOriginOutputNames(std::string op_name) {
1223+
std::vector<std::string> GetOriginOutputNames(const std::string& op_name) {
12241224
std::vector<std::string> ret;
12251225
pir::IrContext* ctx = pir::IrContext::Instance();
12261226
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);

paddle/fluid/framework/new_executor/interpreter/interpreter_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ const paddle::framework::Variable* GetVariableByName(
126126
const std::unordered_map<const paddle::framework::Variable*, std::string>&
127127
variable_2_var_name);
128128

129-
std::vector<std::string> GetOriginInputNames(std::string op_name);
129+
std::vector<std::string> GetOriginInputNames(const std::string& op_name);
130130

131-
std::vector<std::string> GetOriginOutputNames(std::string op_name);
131+
std::vector<std::string> GetOriginOutputNames(const std::string& op_name);
132132

133133
void PrintValuesAndVariables(
134134
const pir::Block& block,

paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ std::shared_ptr<ValueExecutionInfo> ValueExecutionInfo::NewChild(Scope* scope) {
5858
return info;
5959
}
6060

61-
void ValueExecutionInfo::Add(::pir::Value value, std::string var_name) {
61+
void ValueExecutionInfo::Add(::pir::Value value, const std::string& var_name) {
6262
auto* var = scope_->FindVar(var_name);
6363
PADDLE_ENFORCE_NOT_NULL(
6464
var, platform::errors::NotFound("Cannot find %s in scope.", var_name));
@@ -84,8 +84,8 @@ void ValueExecutionInfo::Add(::pir::Value value, std::string var_name) {
8484
}
8585

8686
void ValueExecutionInfo::Rename(pir::Value value,
87-
std::string new_name,
88-
std::string orig_name) {
87+
const std::string& new_name,
88+
const std::string& orig_name) {
8989
value_2_var_name_[value] = new_name;
9090

9191
for (auto kv : value_2_var_name_) {
@@ -344,9 +344,7 @@ void HandleForSpecialOp(pir::Operation* op,
344344
auto value = op->result(0);
345345

346346
value_exe_info->Add(value, fetch_var_name);
347-
}
348-
349-
if (op_name == "pd_op.feed" || op_name == "pd_op.data") {
347+
} else if (op_name == "pd_op.feed" || op_name == "pd_op.data") {
350348
VLOG(6) << "Handle for" << op_name;
351349
auto value = op->result(0);
352350
VLOG(6) << "link feed output to feed in variable"
@@ -360,9 +358,7 @@ void HandleForSpecialOp(pir::Operation* op,
360358
"The variable %s shoud exist", name));
361359

362360
value_exe_info->Add(value, name);
363-
}
364-
365-
if (op_name == "builtin.combine") {
361+
} else if (op_name == "builtin.combine") {
366362
auto out_value = op->result(0);
367363

368364
Variable* var = nullptr;
@@ -386,9 +382,7 @@ void HandleForSpecialOp(pir::Operation* op,
386382
tensor_array->emplace_back(
387383
value_exe_info->GetScope()->FindVar(value_2_var_name.at(value)));
388384
}
389-
}
390-
391-
if (op_name == "builtin.set_parameter") {
385+
} else if (op_name == "builtin.set_parameter") {
392386
VLOG(6) << "Handle for builtin.set_parameter:";
393387
auto param_name = op->attributes()
394388
.at("parameter_name")
@@ -413,8 +407,7 @@ void HandleForSpecialOp(pir::Operation* op,
413407
}
414408

415409
value_exe_info->Rename(value, param_name, orig_name);
416-
}
417-
if (op_name.compare(pir::ShadowOutputOp::name()) == 0) {
410+
} else if (op_name == "builtin.shadow_output") {
418411
VLOG(6) << "Handle for builtin.shadow_ouptut";
419412
auto var_name = op->attributes()
420413
.at("output_name")
@@ -433,9 +426,7 @@ void HandleForSpecialOp(pir::Operation* op,
433426
VLOG(8) << "var " << orig_name << " has been renamed to " << var_name;
434427

435428
value_exe_info->Rename(value, var_name, orig_name);
436-
}
437-
438-
if (op_name == "builtin.get_parameter") {
429+
} else if (op_name == "builtin.get_parameter") {
439430
VLOG(6) << "Handle for builtin.get_parameter:";
440431
auto param_name = op->attributes()
441432
.at("parameter_name")
@@ -444,9 +435,7 @@ void HandleForSpecialOp(pir::Operation* op,
444435
auto value = op->result(0);
445436

446437
value_exe_info->Add(value, param_name);
447-
}
448-
449-
if (op_name == "builtin.slice") {
438+
} else if (op_name == "builtin.slice") {
450439
VLOG(6) << "Handle for builtin.slice";
451440
auto out_value = op->result(0);
452441
auto in_value = op->operand_source(0);
@@ -471,9 +460,7 @@ void HandleForSpecialOp(pir::Operation* op,
471460
std::string var_name =
472461
value_exe_info->GetVar2VarName().at(variable_array[index]);
473462
value_exe_info->AddValue2VarName(out_value, var_name);
474-
}
475-
476-
if (op_name == "builtin.split") {
463+
} else if (op_name == "builtin.split") {
477464
VLOG(6) << "Handle for builtin.split";
478465
auto in_value = op->operand_source(0);
479466
PADDLE_ENFORCE_EQ(value_exe_info->GetValue2VarName().count(in_value),
@@ -497,17 +484,13 @@ void HandleForSpecialOp(pir::Operation* op,
497484
value_exe_info->GetVar2VarName().at(variable_array[idx]);
498485
value_exe_info->AddValue2VarName(out_value, var_name);
499486
}
500-
}
501-
502-
if (op_name == "pd_op.if") {
487+
} else if (op_name == "pd_op.if") {
503488
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
504489
for (size_t i = 0; i < if_op->num_results(); ++i) {
505490
auto if_op_out_value = if_op->result(i);
506491
BuildValue(if_op_out_value, var_name_prefix, value_exe_info);
507492
}
508-
}
509-
510-
if (op_name == "pd_op.while") {
493+
} else if (op_name == "pd_op.while") {
511494
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
512495

513496
for (size_t i = 0; i < while_op->num_results(); ++i) {

paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@ class ValueExecutionInfo {
5656

5757
Scope* GetScope() const { return scope_; }
5858

59-
void Add(::pir::Value value, std::string var_name);
59+
void Add(::pir::Value value, const std::string& var_name);
6060

61-
void Rename(pir::Value value, std::string new_name, std::string orig_name);
61+
void Rename(pir::Value value,
62+
const std::string& new_name,
63+
const std::string& orig_name);
6264

6365
int GetIdByName(const std::string& name) const;
6466

paddle/fluid/framework/new_executor/pir_interpreter.cc

Lines changed: 35 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,7 @@ paddle::framework::FetchList PirInterpreter::Run(
10701070

10711071
// Run
10721072
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
1073+
execution_config_.used_for_inference ||
10731074
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
10741075
(sync_op_num_ == 0))) {
10751076
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
@@ -1085,6 +1086,7 @@ paddle::framework::FetchList PirInterpreter::Run(
10851086
is_shared_results_build_ = true;
10861087
} else {
10871088
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
1089+
execution_config_.used_for_inference ||
10881090
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
10891091
(sync_op_num_ == 0))) {
10901092
TraceRunImpl();
@@ -1096,39 +1098,20 @@ paddle::framework::FetchList PirInterpreter::Run(
10961098
if (HasLocalScope()) {
10971099
ClearLoDTensorArrayInLocalScope();
10981100
}
1101+
10991102
// return Fetch Tensors
11001103
Scope* inner_scope = InnerScope();
1101-
if (FLAGS_enable_new_ir_in_executor) {
1102-
framework::FetchList fetch_res;
1103-
1104-
if (need_fetch) {
1105-
for (auto& var_name : fetch_var_names_) {
1106-
auto* var = inner_scope->FindVar(var_name);
1107-
VLOG(4) << "fetch " << var_name << "[" << var << "]";
1108-
fetch_res.push_back(var->Get<phi::DenseTensor>());
1109-
}
1110-
}
1111-
1112-
VLOG(4) << "get fetch list size: " << fetch_res.size();
1113-
return fetch_res;
1114-
} else {
1115-
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
1116-
if (fetch_var) {
1117-
auto fetch_list =
1118-
std::move(*fetch_var->GetMutable<framework::FetchList>());
1119-
#ifdef PADDLE_WITH_CUDA
1120-
if (platform::IsCUDAGraphCapturing()) {
1121-
PADDLE_ENFORCE_EQ(fetch_list.empty(),
1122-
true,
1123-
platform::errors::InvalidArgument(
1124-
"Cannot fetch data when using CUDA Graph."));
1125-
}
1126-
#endif
1127-
return fetch_list;
1128-
} else {
1129-
return {};
1104+
framework::FetchList fetch_res;
1105+
if (need_fetch) {
1106+
for (auto& var_name : fetch_var_names_) {
1107+
auto* var = inner_scope->FindVar(var_name);
1108+
VLOG(4) << "fetch " << var_name << "[" << var << "]";
1109+
fetch_res.push_back(var->Get<phi::DenseTensor>());
11301110
}
11311111
}
1112+
1113+
VLOG(4) << "get fetch list size: " << fetch_res.size();
1114+
return fetch_res;
11321115
}
11331116

11341117
FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
@@ -1161,6 +1144,7 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
11611144

11621145
// Run
11631146
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
1147+
execution_config_.used_for_inference ||
11641148
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
11651149
(sync_op_num_ == 0))) {
11661150
LOG_FIRST_N(INFO, 1) << "New ir interpreter is running in BetaRun mode "
@@ -1176,6 +1160,7 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
11761160
is_shared_results_build_ = true;
11771161
} else {
11781162
if (FLAGS_enable_new_ir_in_executor_trace_run || nccl_op_num_ > 1 ||
1163+
execution_config_.used_for_inference ||
11791164
((execution_config_.used_for_jit || execution_config_.used_for_cinn) &&
11801165
(sync_op_num_ == 0))) {
11811166
TraceRunImpl();
@@ -1187,38 +1172,21 @@ FetchList PirInterpreter::Run(const std::vector<std::string>& feed_names,
11871172
if (HasLocalScope()) {
11881173
ClearLoDTensorArrayInLocalScope();
11891174
}
1190-
// return Fetch Tensors
1191-
Scope* inner_scope = InnerScope();
1192-
if (FLAGS_enable_new_ir_in_executor) {
1193-
framework::FetchList fetch_res;
1194-
1195-
if (need_fetch) {
1196-
for (auto& var_name : fetch_var_names_) {
1197-
auto* var = inner_scope->FindVar(var_name);
1198-
VLOG(4) << "fetch " << var_name << "[" << var << "]";
1199-
fetch_res.push_back(var->Get<phi::DenseTensor>());
1200-
}
1175+
1176+
framework::FetchList fetch_res;
1177+
if (need_fetch) {
1178+
// return Fetch Tensors
1179+
Scope* inner_scope = InnerScope();
1180+
1181+
for (auto& var_name : fetch_var_names_) {
1182+
auto* var = inner_scope->FindVar(var_name);
1183+
VLOG(4) << "fetch " << var_name << "[" << var << "]";
1184+
fetch_res.push_back(var->Get<phi::DenseTensor>());
12011185
}
1186+
12021187
VLOG(4) << "get fetch list size: " << fetch_res.size();
1203-
return fetch_res;
1204-
} else {
1205-
auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
1206-
if (fetch_var && need_fetch) {
1207-
auto fetch_list =
1208-
std::move(*fetch_var->GetMutable<framework::FetchList>());
1209-
#ifdef PADDLE_WITH_CUDA
1210-
if (platform::IsCUDAGraphCapturing()) {
1211-
PADDLE_ENFORCE_EQ(fetch_list.empty(),
1212-
true,
1213-
platform::errors::InvalidArgument(
1214-
"Cannot fetch data when using CUDA Graph."));
1215-
}
1216-
#endif
1217-
return fetch_list;
1218-
} else {
1219-
return {};
1220-
}
12211188
}
1189+
return fetch_res;
12221190
}
12231191

12241192
void PirInterpreter::TraceRunImpl() {
@@ -1437,10 +1405,11 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
14371405
platform::RecordEvent instruction_event(
14381406
instr_node->Name(), platform::TracerEventType::Operator, 1);
14391407

1440-
SetDeviceId(instr_node->DeviceContext().GetPlace());
1408+
auto cur_place = instr_node->DeviceContext().GetPlace();
1409+
SetDeviceId(cur_place);
14411410

14421411
try {
1443-
instr_node->WaitEvent(place_);
1412+
instr_node->WaitEvent(cur_place);
14441413
VLOG(4) << "begin to run op " << instr_node->Name();
14451414
VLOG(4) << "begin: " << __func__ << " OP id:" << instr_node->Id()
14461415
<< " name:" << instr_node->Name() << " type:"
@@ -1450,7 +1419,8 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
14501419
? "kGpuSync"
14511420
: "kGpuAsync"))
14521421
<< " runs on " << platform::GetCurrentThreadName();
1453-
VLOG(4) << place_ << " Before:"
1422+
1423+
VLOG(4) << cur_place << " Before:"
14541424
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
14551425
if (!instr_node->IsArtificial()) {
14561426
instr_node->Run();
@@ -1472,14 +1442,15 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) {
14721442
? "kGpuSync"
14731443
: "kGpuAsync"))
14741444
<< " runs on " << platform::GetCurrentThreadName();
1475-
VLOG(4) << place_ << " After:"
1445+
1446+
VLOG(4) << cur_place << " After:"
14761447
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
14771448
CheckGC(instr_node);
14781449
VLOG(4) << "done CheckGC";
1479-
memory::LogDeviceMemoryStats(place_, instr_node->Name());
1450+
memory::LogDeviceMemoryStats(cur_place, instr_node->Name());
14801451
}
14811452
VLOG(5) << "after run kernel";
1482-
instr_node->RecordEvent(place_);
1453+
instr_node->RecordEvent(cur_place);
14831454
} catch (platform::EnforceNotMet& ex) {
14841455
auto* op = instr_node->Operation();
14851456
const std::vector<std::string> op_callstack_attr =

0 commit comments

Comments (0)