Skip to content

Commit 67c8dad

Browse files
author
chengduo
authored
Add Event in ScopeBuffer Executor (#17667)
* Add profiler events to the fast executor and add threads to the scope-buffered executor. test=develop
1 parent bba57cd commit 67c8dad

File tree

4 files changed

+33
-24
lines changed

4 files changed

+33
-24
lines changed

paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc

+4-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "paddle/fluid/framework/details/fetch_op_handle.h"
2121
#include "paddle/fluid/framework/details/multi_devices_helper.h"
2222
#include "paddle/fluid/framework/ir/graph_helper.h"
23+
#include "paddle/fluid/platform/profiler.h"
2324

2425
namespace paddle {
2526
namespace framework {
@@ -50,6 +51,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
5051
FeedFetchList FastThreadedSSAGraphExecutor::Run(
5152
const std::vector<std::string> &fetch_tensors) {
5253
VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
54+
std::unique_ptr<platform::RecordEvent> event(
55+
new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
5356
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
5457
op_deps = atomic_op_deps_.get();
5558
PrepareAtomicOpDeps();
@@ -64,7 +67,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
6467

6568
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
6669
&fetch_ops, &ready_fetch_ops);
67-
70+
event.reset(nullptr);
6871
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
6972
// If the num_threads is 1, we can record the order of operator's
7073
// execution in the first iteration, and in subsequent iterations,

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc

+24-22
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,10 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
3636
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
3737
const std::vector<std::string> &fetch_tensors) {
3838
if (drop_scope_counter_ == 0) {
39-
// Create local scopes.
40-
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
41-
auto &scope = *it;
42-
Scope &local_scope = scope->NewScope();
43-
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
44-
&local_scope;
45-
46-
for (auto &info : var_infos_) {
47-
if (scope->FindVar(info.name_) != nullptr) {
48-
continue;
49-
}
50-
51-
if (info.persistable_) { // Persistable
52-
InitializeVariable(scope->Var(info.name_), info.type_);
53-
} else {
54-
InitializeVariable(local_scope.Var(info.name_), info.type_);
55-
}
56-
}
57-
}
39+
platform::RecordEvent e("InitLocalExeScopes");
40+
PrepareLocalExeScopes();
5841
}
42+
5943
std::vector<framework::LoDTensor> fetch_data;
6044
std::exception_ptr eptr = nullptr;
6145
try {
@@ -64,9 +48,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
6448
eptr = std::current_exception();
6549
}
6650

67-
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun");
6851
++drop_scope_counter_;
69-
7052
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
7153
DropLocalExeScopes();
7254
}
@@ -78,11 +60,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
7860
}
7961

8062
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
63+
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
8164
drop_scope_counter_ = 0;
8265
for (auto p : places_) {
8366
platform::DeviceContextPool::Instance().Get(p)->Wait();
8467
}
85-
8668
for (auto &scope : local_scopes_) {
8769
auto &local_scope =
8870
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
@@ -91,6 +73,26 @@ void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
9173
}
9274
}
9375

76+
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
77+
// Create local scopes.
78+
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
79+
auto &scope = *it;
80+
Scope &local_scope = scope->NewScope();
81+
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;
82+
83+
for (auto &info : var_infos_) {
84+
if (scope->FindVar(info.name_) != nullptr) {
85+
continue;
86+
}
87+
if (info.persistable_) { // Persistable
88+
InitializeVariable(scope->Var(info.name_), info.type_);
89+
} else {
90+
InitializeVariable(local_scope.Var(info.name_), info.type_);
91+
}
92+
}
93+
}
94+
}
95+
9496
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
9597
return drop_scope_counter_ == 0;
9698
}

paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
// limitations under the License.
1414

1515
#pragma once
16-
16+
#include <ThreadPool.h>
17+
#include <list>
1718
#include <memory>
1819
#include <string>
1920
#include <vector>
@@ -51,6 +52,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
5152

5253
bool NeedCreateLocalExeScope();
5354

55+
void PrepareLocalExeScopes();
56+
5457
private:
5558
size_t drop_scope_counter_{0};
5659
ExecutionStrategy strategy_;

paddle/fluid/framework/parallel_executor.cc

+1
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
586586

587587
platform::RecordBlock b(0);
588588
if (member_->HasGarbageCollectors()) {
589+
platform::RecordEvent event("PrepareGarbageCollectors");
589590
member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
590591
}
591592

0 commit comments

Comments
 (0)