Skip to content

Commit f00a06d

Browse files
authored
[IR] Refine BuildScope in phi_kernel_util (#55423)
* add code * fix bug * refine code * refine code * fix bug
1 parent 7f6d222 commit f00a06d

File tree

6 files changed

+245
-161
lines changed

6 files changed

+245
-161
lines changed

paddle/fluid/framework/new_executor/new_ir_interpreter.cc

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,20 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
185185

186186
if (!is_build_) {
187187
LOG_FIRST_N(INFO, 1) << "New Executor is Running.";
188-
::ir::BuildScope(
189-
*ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
188+
::ir::BuildScope(*ir_program_->block(),
189+
InnerScope(),
190+
&value_2_var_name_,
191+
&variable_2_var_name_,
192+
&var_name_2_id_,
193+
&variable_list_);
190194

191195
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
192196
interpreter::BuildOpFuncList(place_,
193197
ir_program_->block(),
194198
&op_func_nodes,
195199
scope_,
196200
local_scope_,
197-
value_2_var_name_map_,
201+
value_2_var_name_,
198202
execution_config_);
199203
// SetFeedVarsInplaceSkip(feed_names);
200204
// convert vec func_list to graph
@@ -237,8 +241,12 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
237241
SetDeviceId(place_);
238242
if (!is_build_) {
239243
LOG_FIRST_N(INFO, 1) << "New Executor is BetaRunning.";
240-
::ir::BuildScope(
241-
*ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
244+
::ir::BuildScope(*ir_program_->block(),
245+
InnerScope(),
246+
&value_2_var_name_,
247+
&variable_2_var_name_,
248+
&var_name_2_id_,
249+
&variable_list_);
242250
BuildInstruction();
243251
for (size_t instr_id = 0; instr_id < vec_instruction_base_.size();
244252
++instr_id) {
@@ -1526,13 +1534,8 @@ void NewIRInterpreter::BuildInstruction() {
15261534
++it) {
15271535
VLOG(0) << "Build Instruction for op: " << op_idx;
15281536
if ((*it)->dialect()->name() == "pd_kernel") {
1529-
vec_instruction_base_.emplace_back(
1530-
std::make_unique<PhiKernelInstruction>(op_idx++,
1531-
place_,
1532-
(*it),
1533-
scope_,
1534-
local_scope_,
1535-
value_2_var_name_map_));
1537+
vec_instruction_base_.emplace_back(std::make_unique<PhiKernelInstruction>(
1538+
op_idx++, place_, (*it), scope_, local_scope_, value_2_var_name_));
15361539
} else {
15371540
PADDLE_THROW(platform::errors::Unimplemented(
15381541
"Now only support pd_kernel dialect."));

paddle/fluid/framework/new_executor/new_ir_interpreter.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@ class NewIRInterpreter : public InterpreterBaseImpl {
192192

193193
std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;
194194

195-
std::unordered_map<::ir::Value, std::string> value_2_var_name_map_;
195+
std::unordered_map<::ir::Value, std::string> value_2_var_name_;
196+
std::unordered_map<const paddle::framework::Variable*, std::string>
197+
variable_2_var_name_;
198+
std::map<std::string, int> var_name_2_id_;
199+
std::vector<Variable*> variable_list_;
196200
};
197201

198202
} // namespace framework

paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,18 @@ class PhiKernelAdaptor {
5555

5656
void run_kernel_prog(ir::Program* program) {
5757
auto block = program->block();
58-
std::unordered_map<ir::Value, std::string> name_map;
59-
BuildScope(*block, scope_, nullptr, &name_map);
58+
std::unordered_map<ir::Value, std::string> value_2_var_name;
59+
std::unordered_map<const paddle::framework::Variable*, std::string>
60+
variable_2_var_name;
61+
std::map<std::string, int> var_name_2_id;
62+
std::vector<paddle::framework::Variable*> variable_list;
63+
64+
BuildScope(*block,
65+
scope_,
66+
&value_2_var_name,
67+
&variable_2_var_name,
68+
&var_name_2_id,
69+
&variable_list);
6070
ir::IrContext* ctx = ir::IrContext::Instance();
6171

6272
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
@@ -88,7 +98,8 @@ class PhiKernelAdaptor {
8898
phi::MetaTensor,
8999
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
90100
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
91-
false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
101+
false>(
102+
(*it), value_2_var_name, scope_, nullptr, op_yaml_info_parser, &ctx);
92103

93104
infer_meta_impl->infer_meta_(&ctx);
94105

@@ -108,12 +119,16 @@ class PhiKernelAdaptor {
108119
phi::TensorBase*,
109120
paddle::small_vector<const phi::TensorBase*>,
110121
paddle::small_vector<phi::TensorBase*>,
111-
true>(
112-
(*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
122+
true>((*it),
123+
value_2_var_name,
124+
scope_,
125+
nullptr,
126+
op_yaml_info_parser,
127+
&kernel_ctx);
113128
kernel_fn(&kernel_ctx);
114129

115130
auto out_value = (*it)->result(0);
116-
out_name = name_map[out_value];
131+
out_name = value_2_var_name[out_value];
117132
}
118133
}
119134

0 commit comments

Comments
 (0)