Commit 0306fa1

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into zyf_slice
2 parents: 5308047 + b97af7d

29 files changed: +2925 -495 lines

paddle/fluid/framework/new_executor/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -5,7 +5,9 @@ graph_to_program_pass variable_helper timer monitor)
 cc_library(workqueue SRCS workqueue.cc DEPS enforce)
 cc_library(interpretercore_garbage_collector SRCS interpretercore_garbage_collector.cc DEPS workqueue ${DEVICE_EVENT_LIBS})
 cc_library(interpretercore_util SRCS interpretercore_util.cc DEPS ${INTERPRETERCORE_DEPS})
-cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector)
+cc_library(event_manager SRCS event_manager.cc DEPS ${DEVICE_EVENT_LIBS} glog)
+cc_library(stream_analyzer SRCS stream_analyzer.cc DEPS ${DEVICE_EVENT_LIBS} glog device_context)
+cc_library(interpretercore SRCS interpretercore.cc DEPS workqueue ${DEVICE_EVENT_LIBS} interpretercore_util interpretercore_garbage_collector stream_analyzer event_manager)
 cc_library(standalone_executor SRCS standalone_executor.cc DEPS interpretercore)
 cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue)
 # cc_binary(standalone_executor_test SRCS standalone_executor_test.cc DEPS interpretercore standalone_executor operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
paddle/fluid/framework/new_executor/event_manager.cc

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/new_executor/event_manager.h"
+
+namespace paddle {
+namespace framework {
+
+void EventManager::WaitEvent(const Instruction& instruction,
+                             const platform::Place& place) {
+  // If InterpreterCore in on CPUPlace, do nothing.
+  if (platform::is_cpu_place(place)) return;
+
+  VLOG(3) << "Deal StreamWaitEventOrSync for "
+          << instruction.kernel_func_.operator_base_->Type();
+  auto* dev_ctx = instruction.dev_ctx_;
+
+  WaitOrSync(instruction.intput_events_, dev_ctx);
+}
+
+void EventManager::RecordEvent(const Instruction& instruction,
+                               const OpFuncNode& op_func_node,
+                               const platform::Place& place) {
+  // If InterpreterCore in on CPUPlace, do nothing.
+  if (platform::is_cpu_place(place)) return;
+
+  for (auto& event : instruction.output_events_) {
+    VLOG(3) << "Record event in out_var_id: " << event.var_id_;
+    event.event_->Record(instruction.dev_ctx_);
+  }
+}
+
+void EventManager::WaitOrSync(const std::vector<EventInter>& events,
+                              const platform::DeviceContext* dev_ctx) {
+  for (auto& event_iter : events) {
+    if (event_iter.is_sync_) {
+      VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
+      event_iter.event_->Wait(platform::kCPU, dev_ctx);
+    } else {
+      VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
+      event_iter.event_->Wait(platform::kCUDA, dev_ctx);
+    }
+  }
+}
+
+} // namespace framework
+} // namespace paddle
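
EventManager only reads three fields from each EventInter entry it is handed. A minimal sketch of that record, inferred from the usage above (the real definition lives in new_executor_defs.h and may differ):

struct EventInter {
  size_t var_id_;                                 // cross-stream variable this event guards
  std::shared_ptr<platform::DeviceEvent> event_;  // shared with the producing instruction
  bool is_sync_;                                  // true: block the host (kCPU wait); false: stream wait (kCUDA)
};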
paddle/fluid/framework/new_executor/event_manager.h

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
+
+namespace paddle {
+namespace framework {
+
+class EventManager {
+ public:
+  void RecordEvent(const Instruction& instruction,
+                   const OpFuncNode& op_func_node,
+                   const platform::Place& place);
+
+  void WaitEvent(const Instruction& instruction, const platform::Place& place);
+
+ private:
+  void WaitOrSync(const std::vector<EventInter>& events,
+                  const platform::DeviceContext* dev_ctx);
+};
+
+} // namespace framework
+} // namespace paddle
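
The intended call pattern, taken from the interpretercore.cc hunks below: InterpreterCore waits on an instruction's input events before running it and records its output events afterwards. A sketch of that loop body (member names are those used in interpretercore.cc):

event_manager_.WaitEvent(instr_node, place_);    // stream-wait or host-sync on input events
RunInstruction(instr_node);                      // run the kernel
event_manager_.RecordEvent(instr_node, vec_func_list_[instr_id], place_);  // mark outputs ready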

paddle/fluid/framework/new_executor/interpretercore.cc

Lines changed: 6 additions & 151 deletions
@@ -20,101 +20,6 @@
 namespace paddle {
 namespace framework {
 
-namespace {
-
-/*
- * Parse the var_ids that need to be associated with an event.
- * The caller should guarantee front_op and back_op satisfy the
- * following conditions:
- *   1. kQueueAsync -> kQueueAsync
- *   2. kQueueAsync -> kQueueSync
- *
- * For example: matmul(gpu) -> out_var -> memcpy_d2h
- * out_var should be associated with an event.
- */
-std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
-                                     const Instruction& next_instr) {
-  std::unordered_set<size_t> unique_var_ids;
-  for (auto& item : cur_instr.output_index_) {
-    unique_var_ids.insert(item.second.begin(), item.second.end());
-  }
-
-  std::vector<size_t> new_event_var_ids;
-  for (auto& item : next_instr.input_index_) {
-    for (auto var_id : item.second) {
-      if (unique_var_ids.count(var_id) > 0) {
-        new_event_var_ids.push_back(var_id);
-      }
-    }
-  }
-  return new_event_var_ids;
-}
-
-void AssociateInputWithEvents(
-    const platform::Place& place, const std::vector<size_t>& new_event_var_id,
-    Instruction* next_instr,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    bool is_sync) {
-  for (auto var_id : new_event_var_id) {
-    if (var_id2event->count(var_id) == 0) {
-      auto device_event = std::make_shared<platform::DeviceEvent>(
-          place, platform::GenerateDeviceEventFlag());
-      var_id2event->emplace(var_id, std::move(device_event));
-    }
-    // Add events for next_instr.inputs
-    next_instr->intput_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            is_sync);
-  }
-}
-
-void ParseDirectAndEventRunOps(
-    const platform::Place& place, const std::vector<OpFuncNode>& op_func_nodes,
-    const std::vector<size_t>& downstream_ops, size_t op_index,
-    std::map<size_t, std::shared_ptr<platform::DeviceEvent>>* var_id2event,
-    std::vector<Instruction>* instructions) {
-  auto& op_func_type = op_func_nodes[op_index].type_;
-  auto& cur_instr = instructions->at(op_index);
-  auto& next_instruction = cur_instr.next_instruction_;
-
-  if (op_func_type == OpFuncType::kQueueSync) {
-    // all downstream ops of kQueueSync can directly run, such as CPU -> Any
-    next_instruction.direct_run_ = downstream_ops;
-  } else {  // kQueueAsync
-    std::vector<size_t> event_var_ids;
-    for (auto next_op_id : downstream_ops) {
-      auto& next_instr = instructions->at(next_op_id);
-      // case 1: GPU -> GPU(same stream)
-      if (cur_instr.dev_ctx_ == next_instr.dev_ctx_) {
-        next_instruction.direct_run_.emplace_back(next_op_id);
-        continue;
-      }
-      // Always insert events between different stream
-      auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
-      event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
-                           new_event_var_ids.end());
-
-      bool is_sync =
-          (op_func_nodes[next_op_id].type_ == OpFuncType::kQueueSync);
-      AssociateInputWithEvents(place, new_event_var_ids, &next_instr,
-                               var_id2event, is_sync);
-
-      if (is_sync) {  // GPU -> CPU
-        next_instruction.synchronize_run_.emplace_back(next_op_id);
-      } else {  // GPU -> GPU(different stream)
-        next_instruction.event_wait_run_.emplace_back(next_op_id);
-      }
-    }
-    // Create events for these cross-stream vars
-    VLOG(3) << cur_instr.kernel_func_.operator_base_->Type()
-            << " event_var_ids.size: " << event_var_ids.size();
-    for (auto var_id : event_var_ids) {
-      cur_instr.output_events_.emplace_back(var_id, var_id2event->at(var_id),
-                                            false /*not used*/);
-    }
-  }
-}
-}  // namespace
-
 InterpreterCore::InterpreterCore(const platform::Place& place,
                                  const ProgramDesc& main_prog,
                                  VariableScope* global_scope,
@@ -123,8 +28,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
     : place_(place),
       main_program_(main_prog),
       global_scope_(global_scope),
-      d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}) {
+      stream_analyzer_(place) {
   is_build_ = false;
 
   feed_names_ = feed_names;
@@ -199,7 +103,7 @@ void InterpreterCore::Convert() {
     Instruction temp_inst;
     auto* op_base = op_list_[i];
     temp_inst.dev_ctx_ =
-        ParseDeviceContextForInstruction(vec_func_list_[i], *op_base);
+        stream_analyzer_.ParseDeviceContext(vec_func_list_[i], *op_base);
     temp_inst.kernel_func_.compute_func_ = vec_func_list_[i].kernel_func_;
     temp_inst.kernel_func_.operator_base_ = op_base;
     temp_inst.input_index_ = vec_func_list_[i].input_index;
@@ -270,8 +174,8 @@ void InterpreterCore::Convert() {
       }
     }
 
-    ParseDirectAndEventRunOps(place_, vec_func_list_, filter_next, i,
-                              &var_id2event_, &vec_instruction_);
+    stream_analyzer_.Schedule(vec_func_list_, filter_next, i,
+                              &vec_instruction_);
 
     for (auto inst_id : filter_next) {
      dependecy_count_[inst_id]++;
@@ -361,7 +265,7 @@ void InterpreterCore::ExecuteInstructionList(
     working_queue.pop();
     auto& instr_node = vec_instr[instr_id];
     // step1 : stream_wait (non-block host) or sync (block host)
-    StreamWaitEventOrSync(instr_node);
+    event_manager_.WaitEvent(instr_node, place_);
     // step2: run instruction
     RunInstruction(instr_node);
     ++run_op_number;
@@ -371,7 +275,7 @@ void InterpreterCore::ExecuteInstructionList(
     }
 
     // step3: insert event for out_vars if needed
-    RecordEventInstruction(instr_node, vec_func_list_[instr_id]);
+    event_manager_.RecordEvent(instr_node, vec_func_list_[instr_id], place_);
 
     // step4: update working_queue
     auto& next_instr = instr_node.next_instruction_.all_next_ops_;
@@ -450,54 +354,5 @@ const CostInfo& InterpreterCore::DryRun(
   return dry_run_profiler_.GetCostInfo();
 }
 
-platform::DeviceContext* InterpreterCore::ParseDeviceContextForInstruction(
-    const OpFuncNode& op_func_node, const OperatorBase& op_base) {
-  auto& op_type = op_base.Type();
-  auto* dev_ctx = op_func_node.dev_ctx_;
-  if (op_type == interpretercore::kMemcpyH2D) {
-    VLOG(3) << "Get dev_ctx from d2h_context_pool_";
-    dev_ctx = d2h_ctx_pool_.Get(place_);
-  } else if (op_type == interpretercore::kMemcpyD2H) {
-    VLOG(3) << "Get dev_ctx from h2d_context_pool_";
-    dev_ctx = h2d_ctx_pool_.Get(place_);
-  }
-
-  return dev_ctx;
-}
-
-void InterpreterCore::RecordEventInstruction(const Instruction& instruction,
-                                             const OpFuncNode& op_func_node) {
-  // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  for (auto& event : instruction.output_events_) {
-    VLOG(3) << "Record event in out_var_id: " << event.var_id_;
-    event.event_->Record(instruction.dev_ctx_);
-  }
-}
-
-void InterpreterCore::WaitOrSync(const std::vector<EventInter>& events,
-                                 const platform::DeviceContext* dev_ctx) {
-  for (auto& event_iter : events) {
-    if (event_iter.is_sync_) {
-      VLOG(3) << "host sync wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCPU, dev_ctx);
-    } else {
-      VLOG(3) << "stream async wait in_var_id " << event_iter.var_id_;
-      event_iter.event_->Wait(platform::kCUDA, dev_ctx);
-    }
-  }
-}
-
-void InterpreterCore::StreamWaitEventOrSync(const Instruction& instruction) {
-  // If InterpreterCore in on CPUPlace, do nothing.
-  if (platform::is_cpu_place(place_)) return;
-
-  VLOG(3) << "Deal StreamWaitEventOrSync for "
-          << instruction.kernel_func_.operator_base_->Type();
-  auto* dev_ctx = instruction.dev_ctx_;
-
-  WaitOrSync(instruction.intput_events_, dev_ctx);
-}
 } // namespace framework
 } // namespace paddle
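
The scheduling logic removed above (ParseEventVarIds, AssociateInputWithEvents, ParseDirectAndEventRunOps) moves into the new StreamAnalyzer class, whose source is part of this commit but not shown in this excerpt. A sketch of the interface it must expose, inferred only from the call sites above:

class StreamAnalyzer {
 public:
  explicit StreamAnalyzer(const platform::Place& place);

  // Chooses the DeviceContext an instruction runs on, e.g. dedicated contexts
  // for memcpy_h2d / memcpy_d2h ops (replacing the old d2h/h2d context pools).
  platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node,
                                              const OperatorBase& op_base);

  // For each downstream op, decides whether it can run directly, must
  // stream-wait on an event, or must synchronize the host (formerly
  // ParseDirectAndEventRunOps), attaching input/output events as needed.
  void Schedule(const std::vector<OpFuncNode>& op_func_nodes,
                const std::vector<size_t>& downstream_ops, size_t op_index,
                std::vector<Instruction>* instructions);
};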

paddle/fluid/framework/new_executor/interpretercore.h

Lines changed: 4 additions & 16 deletions
@@ -19,10 +19,12 @@
 #include <unordered_map>
 #include <vector>
 
+#include "paddle/fluid/framework/new_executor/event_manager.h"
 #include "paddle/fluid/framework/new_executor/interpretercore_garbage_collector.h"
 #include "paddle/fluid/framework/new_executor/interpretercore_util.h"
 #include "paddle/fluid/framework/new_executor/new_executor_defs.h"
 #include "paddle/fluid/framework/new_executor/profiler.h"
+#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
 #include "paddle/fluid/framework/new_executor/workqueue.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/tensor.h"
@@ -64,17 +66,6 @@ class InterpreterCore {
                  const VariableScope& var_scope, const platform::Place& place,
                  std::vector<VariableMetaInfo>& working_var_ref);  // NOLINT
 
-  platform::DeviceContext* ParseDeviceContextForInstruction(
-      const OpFuncNode& op_func_node, const OperatorBase& op_base);
-
-  void RecordEventInstruction(const Instruction& instruction,
-                              const OpFuncNode& op_func_node);
-
-  void WaitOrSync(const std::vector<EventInter>& events,
-                  const platform::DeviceContext* dev_ctx);
-
-  void StreamWaitEventOrSync(const Instruction& instruction);
-
   void AddFetch(const std::vector<std::string>& fetch_names);
 
   bool is_build_;
@@ -83,9 +74,6 @@
   ProgramDesc main_program_;
   VariableScope* global_scope_;
 
-  platform::DeviceContextPool d2h_ctx_pool_;
-  platform::DeviceContextPool h2d_ctx_pool_;
-
   std::vector<Instruction> vec_instruction_;
   InstructionInfo instruction_info_;
   std::vector<size_t> dependecy_count_;
@@ -99,8 +87,8 @@
   std::vector<std::string> feed_names_;
 
   InterpreterProfiler dry_run_profiler_;
-
-  std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
+  StreamAnalyzer stream_analyzer_;
+  EventManager event_manager_;
 
   InterpreterCoreGarbageCollector gc_;
   std::vector<paddle::platform::DeviceEvent> gc_event_;

paddle/fluid/framework/new_executor/interpretercore_util.h

Lines changed: 0 additions & 3 deletions
@@ -476,9 +476,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
 
 namespace interpretercore {
 
-static constexpr char kMemcpyH2D[] = "memcpy_h2d";
-static constexpr char kMemcpyD2H[] = "memcpy_d2h";
-
 std::string get_memcpy_type(const platform::Place& src_place,
                             const platform::Place& dst_place);
 
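
Removing the constants here works because new_executor_defs.h (below) now declares them, and get_memcpy_type presumably maps a source/destination place pair onto one of them. A purely illustrative sketch of that mapping, not the actual implementation:

std::string get_memcpy_type(const platform::Place& src_place,
                            const platform::Place& dst_place) {
  if (platform::is_cpu_place(src_place) && !platform::is_cpu_place(dst_place)) {
    return interpretercore::kMemcpyH2D;  // host -> device copy
  }
  if (!platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
    return interpretercore::kMemcpyD2H;  // device -> host copy
  }
  return "";  // assumption: no memcpy op needed when both places are on the same side
}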

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,11 @@
 namespace paddle {
 namespace framework {
 
+namespace interpretercore {
+static constexpr char kMemcpyH2D[] = "memcpy_h2d";
+static constexpr char kMemcpyD2H[] = "memcpy_d2h";
+} // namespace interpretercore
+
 using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
 using OpKernelMap =
     std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
