namespace paddle {
namespace framework {

-namespace {
-
-/*
- * Parse the var_ids that need to be associated with an event.
- * The caller should guarantee front_op and back_op satisfy the
- * following conditions:
- * 1. kQueueAsync -> kQueueAsync
- * 2. kQueueAsync -> kQueueSync
- *
- * For example: matmul(gpu) -> out_var -> memcpy_d2h
- * out_var should be associated with an event.
- */
-std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
-                                     const Instruction& next_instr) {
-  std::unordered_set<size_t> unique_var_ids;
-  for (auto& item : cur_instr.output_index_) {
-    unique_var_ids.insert(item.second.begin(), item.second.end());
-  }
-
-  std::vector<size_t> new_event_var_ids;
-  for (auto& item : next_instr.input_index_) {
-    for (auto var_id : item.second) {
-      if (unique_var_ids.count(var_id) > 0) {
-        new_event_var_ids.push_back(var_id);
-      }
-    }
-  }
-  return new_event_var_ids;
-}
-
-void AssociateInputWithEvents(
-    const platform::Place& place, const std::vector<size_t>& new_event_var_id,
-    Instruction* next_instr,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    bool is_sync) {
-  for (auto var_id : new_event_var_id) {
-    if (var_id2event->count(var_id) == 0) {
-      auto device_event = std::make_shared<platform::DeviceEvent>(
-          place, platform::GenerateDeviceEventFlag());
-      var_id2event->emplace(var_id, std::move(device_event));
-    }
-    // Add events for next_instr.inputs
-    next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            is_sync);
-  }
-}
-
-void ParseDirectAndEventRunOps(
-    const platform::Place& place, const std::vector<OpFuncNode>& op_func_nodes,
-    const std::vector<size_t>& downstream_ops, size_t op_index,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    std::vector<Instruction>* instructions) {
-  auto& op_func_type = op_func_nodes[op_index].type_;
-  auto& cur_instr = instructions->at(op_index);
-  auto& next_instruction = cur_instr.next_instruction_;
-
-  if (op_func_type == OpFuncType::kQueueSync) {
-    // all downstream ops of kQueueSync can directly run, such as CPU -> Any
-    next_instruction.direct_run_ = downstream_ops;
-  } else {  // kQueueAsync
-    std::vector<size_t> event_var_ids;
-    for (auto next_op_id : downstream_ops) {
-      auto& next_instr = instructions->at(next_op_id);
-      // case 1: GPU -> GPU(same stream)
-      if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
-        next_instruction.direct_run_.emplace_back(next_op_id);
-        continue;
-      }
-      // Always insert events between different stream
-      auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
-      event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
-                           new_event_var_ids.end());
-
-      bool is_sync =
-          (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
-      AssociateInputWithEvents(place, new_event_var_ids, &next_instr,
-                               var_id2event, is_sync);
-
-      if (is_sync) {  // GPU -> CPU
-        next_instruction.synchronize_run_.emplace_back(next_op_id);
-      } else {  // GPU -> GPU(different stream)
-        next_instruction.event_wait_run_.emplace_back(next_op_id);
-      }
-    }
-    // Create events for these cross-stream vars
-    VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
-            << " event_var_ids.size: " << event_var_ids.size();
-    for (auto var_id : event_var_ids) {
-      cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            false /*not used*/);
-    }
-  }
-}
-}  // namespace
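The helpers removed above implement one idea: a downstream instruction only needs a cross-stream event for the variables it actually reads from the upstream instruction's outputs, i.e. the intersection of the producer's output var ids and the consumer's input var ids (the `matmul(gpu) -> out_var -> memcpy_d2h` case described in the removed comment). The sketch below restates that intersection with plain standard-library containers; `VarIndexMap` and `CrossStreamVarIds` are illustrative stand-ins for this example, not Paddle types.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>

// Simplified stand-in for Instruction::output_index_ / input_index_:
// a map from parameter name to the variable ids bound to it.
using VarIndexMap = std::map<std::string, std::vector<size_t>>;

// Collect the producer's outputs that the consumer reads; only these
// variables need an event between the two instructions.
std::vector<size_t> CrossStreamVarIds(const VarIndexMap& producer_outputs,
                                      const VarIndexMap& consumer_inputs) {
  std::unordered_set<size_t> produced;
  for (const auto& item : producer_outputs) {
    produced.insert(item.second.begin(), item.second.end());
  }
  std::vector<size_t> event_var_ids;
  for (const auto& item : consumer_inputs) {
    for (size_t var_id : item.second) {
      if (produced.count(var_id) > 0) event_var_ids.push_back(var_id);
    }
  }
  return event_var_ids;
}

int main() {
  // matmul(gpu) writes var 7; memcpy_d2h reads vars 7 and 3.
  VarIndexMap matmul_out{{"Out", {7}}};
  VarIndexMap memcpy_in{{"X", {7, 3}}};
  for (size_t id : CrossStreamVarIds(matmul_out, memcpy_in)) {
    std::cout << "needs event: var " << id << "\n";  // prints var 7
  }
  return 0;
}
```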
-
InterpreterCore::InterpreterCore(const platform::Place& place,
                                 const ProgramDesc& main_prog,
                                 VariableScope* global_scope,
@@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
    : place_(place),
      main_program_(main_prog),
      global_scope_(global_scope),
-      d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}) {
+      stream_analyzer_(place) {
  is_build_ = false;

  feed_names_ = feed_names;
@@ -199,7 +103,7 @@ void InterpreterCore::Convert() {
    Instruction temp_inst;
    auto* op_base = op_list_[i];
    temp_inst.dev_ctx_ =
-        ParseDeviceContextForInstruction(vec_func_list_[i], *op_base);
+        stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base);
    temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_;
    temp_inst.kernel_func_.operator_base_ = op_base;
    temp_inst.input_index_ = vec_func_list_[i].input_index;
@@ -270,8 +174,8 @@ void InterpreterCore::Convert() {
      }
    }

-    ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i,
-                              &var_id2event_, &vec_instruction_);
+    stream_analyzer_.Schedule(vec_func_list_, filter_next, i,
+                              &vec_instruction_);

    for (auto inst_id : filter_next) {
      dependecy_count_[inst_id]++;
@@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList(
    working_queue.pop();
    auto& instr_node = vec_instr[instr_id];
    // step1 : stream_wait (non-block host) or sync (block host)
-    StreamWaitEventOrSync(instr_node);
+    event_manager_.WaitEvent(instr_node, place_);
    // step2: run instruction
    RunInstruction(instr_node);
    ++run_op_number;
@@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList(
    }

    // step3: insert event for out_vars if needed
-    RecordEventInstruction(instr_node, vec_func_list_[instr_id]);
+    event_manager_.RecordEvent(instr_node, vec_func_list_[instr_id], place_);

    // step4: update working_queue
    auto& next_instr = instr_node.next_instruction_.all_next_ops_;
@@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun(
  return dry_run_profiler_.GetCostInfo();
}

-platform::DeviceContext* InterpreterCore::ParseDeviceContextForInstruction(
-    const OpFuncNode& op_func_node, const OperatorBase& op_base) {
-  auto& op_type = op_base.Type();
-  auto* dev_ctx = op_func_node.dev_ctx_;
-  if (op_type == interpretercore::kMemcpyH2D) {
-    VLOG(3) << "Get dev_ctx from d2h_context_pool_";
-    dev_ctx = d2h_ctx_pool_.Get(place_);
-  } else if (op_type == interpretercore::kMemcpyD2H) {
-    VLOG(3) << "Get dev_ctx from h2d_context_pool_";
-    dev_ctx = h2d_ctx_pool_.Get(place_);
-  }
-
-  return dev_ctx;
-}
-
-void InterpreterCore::RecordEventInstruction(const Instruction& instruction,
-                                             const OpFuncNode& op_func_node) {
-  // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  for (auto& event : instruction.output_events_) {
-    VLOG(3) << "Record event in out_var_id: " << event.var_id_;
-    event.event_->Record(instruction.dev_ctx_);
-  }
-}
-
-void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events,
-                                 const platform::DeviceContext* dev_ctx) {
-  for (auto& event_iter : events) {
-    if (event_iter.is_sync_) {
-      VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCPU, dev_ctx);
-    } else {
-      VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCUDA, dev_ctx);
-    }
-  }
-}
-
-void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) {
-  // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  VLOG(3) << "Deal StreamWaitEventOrSync for "
-          << instruction.kernel_func_.operator_base_->Type();
-  auto* dev_ctx = instruction.dev_ctx_;
-
-  WaitOrSync(instruction.intput_events_, dev_ctx);
-}
}  // namespace framework
}  // namespace paddle
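The DeviceEvent plumbing being moved out of InterpreterCore follows the usual producer/consumer handshake on CUDA streams: record an event on the producer's stream after the kernel that writes a variable, then have the consumer either wait on that event from its own stream (GPU -> GPU across streams, non-blocking for the host) or synchronize the host before a CPU consumer runs (GPU -> CPU). Below is a minimal sketch of that pattern using raw CUDA runtime calls; it assumes a CUDA-capable build, omits error checks, and illustrates the synchronization pattern only, not the framework's DeviceEvent API.

```cpp
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void Scale(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= 2.0f;
}

int main() {
  const int n = 1024;
  float* d_x = nullptr;
  cudaMalloc(&d_x, n * sizeof(float));
  std::vector<float> h_x(n, 1.0f);

  cudaStream_t producer, consumer;
  cudaStreamCreate(&producer);
  cudaStreamCreate(&consumer);

  // One event per cross-stream variable, recorded on the producer's stream
  // right after the work that writes it (analogous to output_events_).
  cudaEvent_t ready;
  cudaEventCreateWithFlags(&ready, cudaEventDisableTiming);

  cudaMemcpyAsync(d_x, h_x.data(), n * sizeof(float), cudaMemcpyHostToDevice,
                  producer);
  Scale<<<(n + 255) / 256, 256, 0, producer>>>(d_x, n);
  cudaEventRecord(ready, producer);

  // GPU -> GPU (different stream): the consumer stream waits on the event
  // without blocking the host (the event_wait_run_ case).
  cudaStreamWaitEvent(consumer, ready, 0);
  Scale<<<(n + 255) / 256, 256, 0, consumer>>>(d_x, n);

  // GPU -> CPU: block the host before a CPU consumer reads the result
  // (the synchronize_run_ case).
  cudaEventRecord(ready, consumer);
  cudaEventSynchronize(ready);
  cudaMemcpy(h_x.data(), d_x, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("h_x[0] = %f\n", h_x[0]);  // 4.0 after two scalings

  cudaEventDestroy(ready);
  cudaStreamDestroy(producer);
  cudaStreamDestroy(consumer);
  cudaFree(d_x);
  return 0;
}
```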