@@ -75,11 +75,11 @@ bool IsLastUser(const pir::Value& value,
75
75
auto current_value = value;
76
76
while (use_count_map.at (current_value) == 0 ) {
77
77
if (inplace_map.count (current_value) == 0 ) {
78
- return false ;
78
+ return true ;
79
79
}
80
80
current_value = inplace_map.at (current_value);
81
81
}
82
- return true ;
82
+ return false ;
83
83
}
84
84
85
85
bool CanDoInplace (const std::unordered_set<pir::Value>& eager_dels,
@@ -238,10 +238,16 @@ bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) {
238
238
239
239
// NOTE(zhangbo): pd_op.feed's output and pd_op.fetch's input can not be eager
240
240
// deleted.
241
- std::unordered_set<pir::Value> GetSkipDeletionValues (const pir::Block& block) {
241
+ std::unordered_set<pir::Value> GetSkipDeletionValues (
242
+ const pir::Block& block,
243
+ const std::set<std::string>& no_need_buffer_values) {
242
244
std::unordered_set<pir::Value> skip_dels;
243
245
for (auto & op : block) {
244
- if (op.name () == " builtin.shadow_output" ) {
246
+ if (op.name () == " builtin.shadow_output" &&
247
+ no_need_buffer_values.count (op.attributes ()
248
+ .at (" output_name" )
249
+ .dyn_cast <pir::StrAttribute>()
250
+ .AsString ()) == 0 ) {
245
251
skip_dels.insert (op.operand_source (0 ));
246
252
continue ;
247
253
}
@@ -291,14 +297,20 @@ void GetEagerDelValueOfOp(
291
297
.AsString ();
292
298
}
293
299
300
+ if (upper_op_name == " builtin.shadow_output" ) {
301
+ continue ;
302
+ }
303
+
294
304
for (size_t i = 0 ; i < op.num_operands (); ++i) {
295
305
auto input = op.operand_source (i);
296
- if (skip_dels.count (input) > 0 || !input || !CanBeDeleted (input)) {
306
+ if (skip_dels.count (input) > 0 || !input || !CanBeDeleted (input) ||
307
+ IsNoNeedBuffer (&op, input)) {
297
308
VLOG (6 ) << " The " << i << " -th input value of the Operation("
298
309
<< upper_op_name << " ) can not be deleted." ;
299
310
VLOG (8 ) << " -- skip dels: " << skip_dels.count (input);
300
311
VLOG (8 ) << " -- value is null: " << !input;
301
312
VLOG (8 ) << " -- can be deleted: " << !CanBeDeleted (input);
313
+ VLOG (8 ) << " -- is no_need_buffer: " << IsNoNeedBuffer (&op, input);
302
314
continue ;
303
315
}
304
316
(*del_value_2_op)[input] = &op;
@@ -323,8 +335,10 @@ void GetEagerDelValueOfOp(
323
335
}
324
336
325
337
std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
326
- GetEagerDeletionValues (const pir::Block& block) {
327
- std::unordered_set<pir::Value> skip_dels = GetSkipDeletionValues (block);
338
+ GetEagerDeletionValues (const pir::Block& block,
339
+ const std::set<std::string>& no_need_buffer_values) {
340
+ std::unordered_set<pir::Value> skip_dels =
341
+ GetSkipDeletionValues (block, no_need_buffer_values);
328
342
std::unordered_map<pir::Value, pir::Operation*> del_value_2_op;
329
343
GetEagerDelValueOfOp (block, skip_dels, &del_value_2_op);
330
344
std::unordered_map<pir::Operation*, std::unordered_set<pir::Value>>
@@ -336,13 +350,31 @@ GetEagerDeletionValues(const pir::Block& block) {
336
350
}
337
351
338
352
std::unordered_map<pir::Operation*, std::string> GetInplaceOps (
339
- const pir::Block& block) {
340
- const auto eager_dels = GetEagerDeletionValues (block);
341
- auto use_count_map = [](const pir::Block& block) {
353
+ const pir::Block& block,
354
+ const std::set<std::string>& no_need_buffer_values) {
355
+ const auto eager_dels = GetEagerDeletionValues (block, no_need_buffer_values);
356
+
357
+ auto is_no_need_buffer = [&no_need_buffer_values](pir::Operation* op,
358
+ pir::Value value) {
359
+ if (auto shadow_output_op = op->dyn_cast <pir::ShadowOutputOp>()) {
360
+ if (no_need_buffer_values.count (shadow_output_op.attributes ()
361
+ .at (" output_name" )
362
+ .dyn_cast <pir::StrAttribute>()
363
+ .AsString ())) {
364
+ return true ;
365
+ }
366
+ }
367
+ return IsNoNeedBuffer (op, value);
368
+ };
369
+ auto use_count_map = [&is_no_need_buffer](const pir::Block& block) {
342
370
std::unordered_map<pir::Value, size_t > use_count_map;
343
371
for (auto & op : block) {
344
372
for (auto value : op.results ()) {
345
- use_count_map[value] = value.use_count ();
373
+ size_t use_count = 0 ;
374
+ for (auto it = value.use_begin (); it != value.use_end (); ++it) {
375
+ use_count += is_no_need_buffer (it->owner (), value) ? 0 : 1 ;
376
+ }
377
+ use_count_map[value] = use_count;
346
378
}
347
379
}
348
380
return use_count_map;
@@ -356,7 +388,8 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
356
388
for (auto & op : block) {
357
389
for (size_t i = 0 ; i < op.num_operands (); ++i) {
358
390
visited_values.insert (op.operand_source (i));
359
- use_count_map[op.operand_source (i)]--;
391
+ use_count_map[op.operand_source (i)] -=
392
+ is_no_need_buffer (&op, op.operand_source (i)) ? 0 : 1 ;
360
393
}
361
394
362
395
if (op.dialect ()->name () != paddle::dialect::KernelDialect::name ()) {
@@ -468,7 +501,7 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
468
501
upper_op_name)) ||
469
502
(visited_values.count (op.result (out_slot)) > 0 ) ||
470
503
(!CanBeDeleted (op.result (out_slot))) ||
471
- IsLastUser (op.operand_source (in_slot), use_count_map, inplace_map) ||
504
+ ! IsLastUser (op.operand_source (in_slot), use_count_map, inplace_map) ||
472
505
(std::find (used_external_values.begin (),
473
506
used_external_values.end (),
474
507
op.operand_source (in_slot)) !=
@@ -493,7 +526,7 @@ std::unordered_map<pir::Operation*, std::string> GetInplaceOps(
493
526
<< " -- result " << out_slot
494
527
<< " visited: " << (visited_values.count (op.result (out_slot)) > 0 );
495
528
VLOG_IF (8 , in_slot < op.num_operands ())
496
- << " -- operand " << in_slot << " has not user: "
529
+ << " -- operand " << in_slot << " is last user: "
497
530
<< IsLastUser (
498
531
op.operand_source (in_slot), use_count_map, inplace_map);
499
532
break ;
@@ -525,12 +558,17 @@ class InplacePass : public pir::Pass {
525
558
public:
526
559
InplacePass () : pir::Pass(" inplace_pass" , 3 ) {}
527
560
561
+ explicit InplacePass (const std::set<std::string>& no_need_buffer_values)
562
+ : pir::Pass(" inplace_pass" , 3 ) {
563
+ no_need_buffer_values_ = no_need_buffer_values;
564
+ }
565
+
528
566
void Run (pir::Operation* op) override {
529
567
int64_t num_rewrites_{0 };
530
568
for (size_t i = 0 ; i < op->num_regions (); ++i) {
531
569
auto & region = op->region (i);
532
570
for (auto & block : region) {
533
- auto inplace_ops = GetInplaceOps (block);
571
+ auto inplace_ops = GetInplaceOps (block, no_need_buffer_values_ );
534
572
535
573
for (const auto & kv : inplace_ops) {
536
574
VLOG (6 ) << " Do inplace for: "
@@ -558,12 +596,16 @@ class InplacePass : public pir::Pass {
558
596
}
559
597
AddStatistics (num_rewrites_);
560
598
}
599
+
600
+ private:
601
+ std::set<std::string> no_need_buffer_values_;
561
602
};
562
603
563
604
namespace pir {
564
605
565
- std::unique_ptr<pir::Pass> CreateInplacePass () {
566
- return std::make_unique<InplacePass>();
606
+ std::unique_ptr<pir::Pass> CreateInplacePass (
607
+ const std::set<std::string>& no_need_buffer_values) {
608
+ return std::make_unique<InplacePass>(no_need_buffer_values);
567
609
}
568
610
569
611
} // namespace pir
0 commit comments