Skip to content

Commit 3b27bf0

Browse files
[PIR][Dy2St] Fix ReshapeOp cannot inplace in PIR mode (#69428)
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
1 parent a9f8221 commit 3b27bf0

File tree

5 files changed: 85 additions and 34 deletions

paddle/fluid/eager/to_static/run_program_op_node.h

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -438,13 +438,21 @@ inline void PirRunProgramAPI(
438438
VLOG(2) << "No interpretercore cache, so create a new interpretercore "
439439
"for program: "
440440
<< program_id;
441-
// Step 1. share input_vars & parameters into scope
441+
442+
// Step 1. Get no need buffer vars for inplace pass and gc
443+
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
444+
attrs.at("no_need_buffers"));
445+
const auto no_need_buffer_names =
446+
details::GetNameFromValue(no_need_buffer_values);
447+
const auto no_need_buffer_name_set = std::set<std::string>(
448+
no_need_buffer_names.begin(), no_need_buffer_names.end());
449+
// Step 2. share input_vars & parameters into scope
442450
details::ShareTensorsIntoScopeByValue(x, input_values, global_inner_scope);
443451
details::ShareTensorsIntoScopeByValue(
444452
params, param_values, global_inner_scope);
445-
// Step 2. create new interpretercore
446-
auto passed_kernel_program =
447-
paddle::framework::ApplyIrPass(forward_program.get(), place);
453+
// Step 3. create new interpretercore
454+
auto passed_kernel_program = paddle::framework::ApplyIrPass(
455+
forward_program.get(), place, no_need_buffer_name_set);
448456
interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache(
449457
std::move(passed_kernel_program),
450458
place,
@@ -453,7 +461,7 @@ inline void PirRunProgramAPI(
453461
global_inner_scope,
454462
place_hash_key,
455463
in_sot_mode);
456-
// Step 3. get all eager gc vars (skip_names = backward_inputs -
464+
// Step 4. get all eager gc vars (skip_names = backward_inputs -
457465
// no_need_buffers + outputs)
458466
std::vector<std::string> skip_names;
459467
// update interpretercore skip_gc_var
@@ -462,10 +470,6 @@ inline void PirRunProgramAPI(
462470
}
463471
auto skip_names_set =
464472
std::set<std::string>(skip_names.begin(), skip_names.end());
465-
auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>,
466-
attrs.at("no_need_buffers"));
467-
auto no_need_buffer_names =
468-
details::GetNameFromValue(no_need_buffer_values);
469473
for (auto &name : no_need_buffer_names) {
470474
VLOG(4) << "Find no need buffer vars with name:" << name;
471475
skip_names_set.erase(name);
@@ -990,7 +994,7 @@ inline void PirRunProgramGradAPI(
990994
VLOG(2) << "No interpretercore cache, so create a new interpretercore";
991995
// Step 1. share input_vars & parameters into scope
992996
auto passed_kernel_program =
993-
paddle::framework::ApplyIrPass(backward_program.get(), place);
997+
paddle::framework::ApplyIrPass(backward_program.get(), place, {});
994998

995999
const auto &new_block = passed_kernel_program->block();
9961000
passed_kernel_program = paddle::framework::ApplyRemoveShadowFeedPass(

paddle/fluid/framework/executor_cache.cc

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -172,13 +172,15 @@ bool TensorSortHelper(const paddle::Tensor &t1, const paddle::Tensor &t2) {
172172
return t1.name() < t2.name();
173173
}
174174

175-
std::unique_ptr<::pir::Program> ApplyIrPass(::pir::Program *program,
176-
phi::Place place) {
175+
std::unique_ptr<::pir::Program> ApplyIrPass(
176+
::pir::Program *program,
177+
phi::Place place,
178+
const std::set<std::string> &no_need_buffer_names) {
177179
auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program, place);
178180

179181
if (FLAGS_pir_apply_inplace_pass) {
180182
::pir::PassManager pm(::pir::IrContext::Instance(), 3);
181-
pm.AddPass(::pir::CreateInplacePass());
183+
pm.AddPass(::pir::CreateInplacePass(no_need_buffer_names));
182184
pm.Run(ir_res.get());
183185

184186
if (FLAGS_print_ir) {
@@ -314,7 +316,7 @@ std::unique_ptr<::pir::Program> ConstructForwardIrProgram(
314316
}
315317
auto program = TranslateLegacyProgramToProgram(local_program);
316318

317-
return ApplyIrPass(program.get(), place);
319+
return ApplyIrPass(program.get(), place, {});
318320
}
319321

320322
std::unique_ptr<::pir::Program> ConstructBackwardIrProgram(

paddle/fluid/framework/executor_cache.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -162,8 +162,10 @@ std::shared_ptr<InterpreterCore> CreatePirInterpreterCoreInfoToCache(
162162
const int64_t& place_hash_key,
163163
bool used_for_sot);
164164

165-
std::unique_ptr<::pir::Program> ApplyIrPass(::pir::Program* program,
166-
phi::Place place);
165+
std::unique_ptr<::pir::Program> ApplyIrPass(
166+
::pir::Program* program,
167+
phi::Place place,
168+
const std::set<std::string>& no_need_buffer_names);
167169

168170
std::unique_ptr<::pir::Program> ApplyRemoveShadowFeedPass(
169171
const std::unique_ptr<::pir::Program> program,

paddle/fluid/pir/transforms/general/inplace_pass.cc

Lines changed: 59 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -75,11 +75,11 @@ bool IsLastUser(const pir::Value& value,
7575
auto current_value = value;
7676
while (use_count_map.at(current_value) == 0) {
7777
if (inplace_map.count(current_value) == 0) {
78-
return false;
78+
return true;
7979
}
8080
current_value = inplace_map.at(current_value);
8181
}
82-
return true;
82+
return false;
8383
}
8484

8585
bool CanDoInplace(const std::unordered_set<pir::Value>& eager_dels,
@@ -238,10 +238,16 @@ bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) {
238238

239239
// NOTE(zhangbo): pd_op.feed's output and pd_op.fetch's input can not be eager
240240
// deleted.
241-
std::unordered_set<pir::Value> GetSkipDeletionValues(const pir::Block& block) {
241+
std::unordered_set<pir::Value> GetSkipDeletionValues(
242+
const pir::Block& block,
243+
const std::set<std::string>& no_need_buffer_values) {
242244
std::unordered_set<pir::Value> skip_dels;
243245
for (auto& op : block) {
244-
if (op.name() == "builtin.shadow_output") {
246+
if (op.name() == "builtin.shadow_output" &&
247+
no_need_buffer_values.count(op.attributes()
248+
.at("output_name")
249+
.dyn_cast<pir::StrAttribute>()
250+
.AsString()) == 0) {
245251
skip_dels.insert(op.operand_source(0));
246252
continue;
247253
}
@@ -291,14 +297,20 @@ void GetEagerDelValueOfOp(
291297
.AsString();
292298
}
293299

300+
if (upper_op_name == "builtin.shadow_output") {
301+
continue;
302+
}
303+
294304
for (size_t i = 0; i < op.num_operands(); ++i) {
295305
auto input = op.operand_source(i);
296-
if (skip_dels.count(input) > 0 || !input || !CanBeDeleted(input)) {
306+
if (skip_dels.count(input) > 0 || !input || !CanBeDeleted(input) ||
307+
IsNoNeedBuffer(&op, input)) {
297308
VLOG(6) << "The " << i << "-th input value of the Operation("
298309
<< upper_op_name << ") can not be deleted.";
299310
VLOG(8) << " -- skip dels: " << skip_dels.count(input);
300311
VLOG(8) << " -- value is null: " << !input;
301312
VLOG(8) << " -- can be deleted: " << !CanBeDeleted(input);
313+
VLOG(8) << " -- is no_need_buffer: " << IsNoNeedBuffer(&op, input);
302314
continue;
303315
}
304316
(*del_value_2_op)[input] = &op;
@@ -323,8 +335,10 @@ void GetEagerDelValueOfOp(
323335
}
324336

325337
std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
326-
GetEagerDeletionValues(const pir::Block& block) {
327-
std::unordered_set<pir::Value> skip_dels = GetSkipDeletionValues(block);
338+
GetEagerDeletionValues(const pir::Block& block,
339+
const std::set<std::string>& no_need_buffer_values) {
340+
std::unordered_set<pir::Value> skip_dels =
341+
GetSkipDeletionValues(block, no_need_buffer_values);
328342
std::unordered_map<pir::Value, pir::Operation*> del_value_2_op;
329343
GetEagerDelValueOfOp(block, skip_dels, &del_value_2_op);
330344
std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
@@ -336,13 +350,31 @@ GetEagerDeletionValues(const pir::Block& block) {
336350
}
337351

338352
std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
339-
const pir::Block& block) {
340-
const auto eager_dels = GetEagerDeletionValues(block);
341-
auto use_count_map = [](const pir::Block& block) {
353+
const pir::Block& block,
354+
const std::set<std::string>& no_need_buffer_values) {
355+
const auto eager_dels = GetEagerDeletionValues(block, no_need_buffer_values);
356+
357+
auto is_no_need_buffer = [&no_need_buffer_values](pir::Operation* op,
358+
pir::Value value) {
359+
if (auto shadow_output_op = op->dyn_cast<pir::ShadowOutputOp>()) {
360+
if (no_need_buffer_values.count(shadow_output_op.attributes()
361+
.at("output_name")
362+
.dyn_cast<pir::StrAttribute>()
363+
.AsString())) {
364+
return true;
365+
}
366+
}
367+
return IsNoNeedBuffer(op, value);
368+
};
369+
auto use_count_map = [&is_no_need_buffer](const pir::Block& block) {
342370
std::unordered_map<pir::Value, size_t> use_count_map;
343371
for (auto& op : block) {
344372
for (auto value : op.results()) {
345-
use_count_map[value] = value.use_count();
373+
size_t use_count = 0;
374+
for (auto it = value.use_begin(); it != value.use_end(); ++it) {
375+
use_count += is_no_need_buffer(it->owner(), value) ? 0 : 1;
376+
}
377+
use_count_map[value] = use_count;
346378
}
347379
}
348380
return use_count_map;
@@ -356,7 +388,8 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
356388
for (auto& op : block) {
357389
for (size_t i = 0; i < op.num_operands(); ++i) {
358390
visited_values.insert(op.operand_source(i));
359-
use_count_map[op.operand_source(i)]--;
391+
use_count_map[op.operand_source(i)] -=
392+
is_no_need_buffer(&op, op.operand_source(i)) ? 0 : 1;
360393
}
361394

362395
if (op.dialect()->name() != paddle::dialect::KernelDialect::name()) {
@@ -468,7 +501,7 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
468501
upper_op_name)) ||
469502
(visited_values.count(op.result(out_slot)) > 0) ||
470503
(!CanBeDeleted(op.result(out_slot))) ||
471-
IsLastUser(op.operand_source(in_slot), use_count_map, inplace_map) ||
504+
!IsLastUser(op.operand_source(in_slot), use_count_map, inplace_map) ||
472505
(std::find(used_external_values.begin(),
473506
used_external_values.end(),
474507
op.operand_source(in_slot)) !=
@@ -493,7 +526,7 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
493526
<< " -- result " << out_slot
494527
<< " visited: " << (visited_values.count(op.result(out_slot)) > 0);
495528
VLOG_IF(8, in_slot < op.num_operands())
496-
<< " -- operand " << in_slot << " has not user: "
529+
<< " -- operand " << in_slot << " is last user: "
497530
<< IsLastUser(
498531
op.operand_source(in_slot), use_count_map, inplace_map);
499532
break;
@@ -525,12 +558,17 @@ class InplacePass : public pir::Pass {
525558
public:
526559
InplacePass() : pir::Pass("inplace_pass", 3) {}
527560

561+
explicit InplacePass(const std::set<std::string>& no_need_buffer_values)
562+
: pir::Pass("inplace_pass", 3) {
563+
no_need_buffer_values_ = no_need_buffer_values;
564+
}
565+
528566
void Run(pir::Operation* op) override {
529567
int64_t num_rewrites_{0};
530568
for (size_t i = 0; i < op->num_regions(); ++i) {
531569
auto& region = op->region(i);
532570
for (auto& block : region) {
533-
auto inplace_ops = GetInplaceOps(block);
571+
auto inplace_ops = GetInplaceOps(block, no_need_buffer_values_);
534572

535573
for (const auto& kv : inplace_ops) {
536574
VLOG(6) << "Do inplace for: "
@@ -558,12 +596,16 @@ class InplacePass : public pir::Pass {
558596
}
559597
AddStatistics(num_rewrites_);
560598
}
599+
600+
private:
601+
std::set<std::string> no_need_buffer_values_;
561602
};
562603

563604
namespace pir {
564605

565-
std::unique_ptr<pir::Pass> CreateInplacePass() {
566-
return std::make_unique<InplacePass>();
606+
std::unique_ptr<pir::Pass> CreateInplacePass(
607+
const std::set<std::string>& no_need_buffer_values) {
608+
return std::make_unique<InplacePass>(no_need_buffer_values);
567609
}
568610

569611
} // namespace pir

paddle/fluid/pir/transforms/general/inplace_pass.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ namespace pir {
2121

2222
class Pass;
2323

24-
std::unique_ptr<Pass> CreateInplacePass();
24+
std::unique_ptr<Pass> CreateInplacePass(
25+
const std::set<std::string>& no_need_buffer_values = {});
2526

2627
} // namespace pir

0 commit comments

Comments
 (0)