@@ -60,18 +60,26 @@ const std::string StringizeDownstreamMap(
60
60
return oss.str ();
61
61
}
62
62
63
+ DependencyBuilder::DependencyBuilder ()
64
+ : is_build_(false ), instructions_(nullptr ) {
65
+ op_downstream_map_ = std::make_shared<std::map<size_t , std::set<size_t >>>();
66
+ op_happens_before_ = std::make_shared<std::vector<std::vector<bool >>>();
67
+ }
68
+
63
69
const std::map<size_t , std::set<size_t >>& DependencyBuilder::Build (
64
70
const std::vector<Instruction>& instructions) {
65
71
if (is_build_) {
66
- return op_downstream_map_;
72
+ return * op_downstream_map_;
67
73
}
68
74
75
+ std::tie (op_downstream_map_, op_happens_before_) = GetDependency ();
76
+
69
77
instructions_ = &instructions;
70
78
op_num_ = instructions_->size ();
71
79
72
80
ops_before_.assign (op_num_, {});
73
81
ops_behind_.assign (op_num_, {});
74
- op_happens_before_. assign (op_num_, std::vector<bool >(op_num_, false ));
82
+ op_happens_before_-> assign (op_num_, std::vector<bool >(op_num_, false ));
75
83
76
84
BuildDownstreamMap ();
77
85
VLOG (6 ) << " Finish BuildDownstreamMap" ;
@@ -97,13 +105,24 @@ const std::map<size_t, std::set<size_t>>& DependencyBuilder::Build(
97
105
VLOG (6 ) << " Finish AddDependencyForReadOp" ;
98
106
99
107
VLOG (6 ) << " Finish build dependency" ;
100
- VLOG (8 ) << " downstream count: " << CountDownstreamMap (op_downstream_map_);
108
+ VLOG (8 ) << " downstream count: " << CountDownstreamMap (* op_downstream_map_);
101
109
VLOG (8 ) << " downstream_map: " << std::endl
102
- << StringizeDownstreamMap (op_downstream_map_);
110
+ << StringizeDownstreamMap (* op_downstream_map_);
103
111
104
112
is_build_ = true ;
105
113
106
- return op_downstream_map_;
114
+ return *op_downstream_map_;
115
+ }
116
+
117
+ std::tuple<std::shared_ptr<std::map<size_t , std::set<size_t >>>,
118
+ std::shared_ptr<std::vector<std::vector<bool >>>>
119
+ DependencyBuilder::GetDependency () const {
120
+ return std::make_tuple (op_downstream_map_, op_happens_before_);
121
+ }
122
+
123
+ void DependencyBuilder::ShareDependencyFrom (const DependencyBuilder& src) {
124
+ std::tie (op_downstream_map_, op_happens_before_) = src.GetDependency ();
125
+ is_build_ = true ;
107
126
}
108
127
109
128
const std::map<size_t , std::set<size_t >>& DependencyBuilder::OpDownstreamMap ()
@@ -113,7 +132,7 @@ const std::map<size_t, std::set<size_t>>& DependencyBuilder::OpDownstreamMap()
113
132
true ,
114
133
phi::errors::Unavailable (
115
134
" DependencyBuilder is not yet built, call Build() firstly." ));
116
- return op_downstream_map_;
135
+ return * op_downstream_map_;
117
136
}
118
137
119
138
void DependencyBuilder::AddDependencyForCoalesceTensorOp () {
@@ -268,8 +287,8 @@ void DependencyBuilder::AddDependencyForRandomOp() {
268
287
void DependencyBuilder::AddDependencyForReadOp () {
269
288
std::vector<bool > is_startup_ops (op_num_, true );
270
289
for (size_t op_idx = 0 ; op_idx < op_num_; ++op_idx) {
271
- auto it = op_downstream_map_. find (op_idx);
272
- if (it != op_downstream_map_. end ()) {
290
+ auto it = op_downstream_map_-> find (op_idx);
291
+ if (it != op_downstream_map_-> end ()) {
273
292
for (size_t downstream_op_idx : it->second ) {
274
293
is_startup_ops[downstream_op_idx] = false ;
275
294
}
@@ -320,8 +339,7 @@ void DependencyBuilder::AddDownstreamOp(size_t prior_op_idx,
320
339
posterior_op_idx,
321
340
posterior_op_idx,
322
341
prior_op_idx));
323
-
324
- std::set<size_t >& downstream_ops = op_downstream_map_[prior_op_idx];
342
+ std::set<size_t >& downstream_ops = (*op_downstream_map_)[prior_op_idx];
325
343
// NOTE(Ruibiao): Here the downstream map shrinking is best-effort, therefore
326
344
// ShrinkDownstreamMap after BuildDownstreamMap is still helpful. For example,
327
345
// a->c will not be shrinked in the following case: AddDownstreamOp(a, b) ->
@@ -342,8 +360,8 @@ void DependencyBuilder::AddDownstreamOp(size_t prior_op_idx,
342
360
343
361
auto update_op_happen_before = [this ](size_t prior_op_idx,
344
362
size_t posterior_op_idx) {
345
- if (!op_happens_before_[prior_op_idx][posterior_op_idx]) {
346
- op_happens_before_[prior_op_idx][posterior_op_idx] = true ;
363
+ if (!(* op_happens_before_) [prior_op_idx][posterior_op_idx]) {
364
+ (* op_happens_before_) [prior_op_idx][posterior_op_idx] = true ;
347
365
ops_before_[posterior_op_idx].push_back (prior_op_idx);
348
366
ops_behind_[prior_op_idx].push_back (posterior_op_idx);
349
367
}
@@ -377,8 +395,8 @@ void DependencyBuilder::BuildDownstreamMap() {
377
395
std::map<size_t , size_t >(); // # map from variable to recent write op.
378
396
auto op2dependences =
379
397
std::map<size_t ,
380
- std::set<size_t >>(); // # map from op to the dependence list,
381
- // op must run after the dependence.
398
+ std::set<size_t >>(); // # map from op to the dependence list,
399
+ // op must run after the dependence.
382
400
std::set<size_t >
383
401
remove_duplicate; // remove the duplicate between inputs and outputs
384
402
@@ -497,15 +515,15 @@ void DependencyBuilder::ShrinkDownstreamMap() {
497
515
// shrink, find the downstream op that has no other op in the
498
516
// downstream list happens before it
499
517
for (size_t i = 0 ; i < op_num_; ++i) {
500
- if (op_downstream_map_. find (i) == op_downstream_map_. end ()) {
518
+ if (op_downstream_map_-> find (i) == op_downstream_map_-> end ()) {
501
519
continue ;
502
520
}
503
521
504
522
std::set<size_t > minumum_nexts;
505
- for (size_t item : op_downstream_map_. at (i)) {
523
+ for (size_t item : op_downstream_map_-> at (i)) {
506
524
bool not_after_any = true ;
507
525
// find the op that is not executed after any
508
- for (size_t other_item : op_downstream_map_. at (i)) {
526
+ for (size_t other_item : op_downstream_map_-> at (i)) {
509
527
if (OpHappensBefore (other_item, item)) {
510
528
VLOG (8 ) << " happens_before: " << other_item << " ->" << item
511
529
<< " , so skip " << item;
@@ -520,12 +538,12 @@ void DependencyBuilder::ShrinkDownstreamMap() {
520
538
}
521
539
// NOTE(Ruibiao): op_happens_before will not be changed when shrink
522
540
// dowstream map
523
- op_downstream_map_. at (i) = minumum_nexts;
541
+ (*op_downstream_map_)[i] = minumum_nexts;
524
542
}
525
543
VLOG (8 ) << " Finish shrink downstream map" ;
526
- VLOG (8 ) << " downstream count: " << CountDownstreamMap (op_downstream_map_);
544
+ VLOG (8 ) << " downstream count: " << CountDownstreamMap (* op_downstream_map_);
527
545
VLOG (8 ) << " downstream_map: " << std::endl
528
- << StringizeDownstreamMap (op_downstream_map_);
546
+ << StringizeDownstreamMap (* op_downstream_map_);
529
547
}
530
548
531
549
// / ======================== ///
0 commit comments