Skip to content

Commit 8923612

Browse files
baojun-nervanatensor-tang
authored andcommitted
NGraph enable parse serialized graph test=develop (#17453)
1 parent cf5d271 commit 8923612

File tree

2 files changed

+44
-25
lines changed

2 files changed

+44
-25
lines changed

paddle/fluid/operators/ngraph/ngraph_engine.cc

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ static void SubstituteNgraphOp(
158158
ng_op_desc.SetAttr("interval", interval);
159159
ng_op_desc.SetAttr("engine_key", engine_key);
160160
ng_op_desc.SetAttr("graph", block_str);
161+
ng_op_desc.SetInput("Xs", std::vector<std::string>(0));
162+
ng_op_desc.SetOutput("Ys", std::vector<std::string>(0));
161163

162164
ops->erase(ops->begin() + interval[0], ops->begin() + interval[1]);
163165
ops->insert(ops->begin() + interval[0],
@@ -223,20 +225,36 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
223225
const platform::Place& place,
224226
const framework::ExecutionContext& ctx)
225227
: scope_(scope), place_(place) {
226-
std::string serialized_graph = ctx.Attr<std::string>("graph");
227-
auto interval = ctx.Attr<std::vector<int>>("interval");
228-
std::string engine_key = ctx.Attr<std::string>("engine_key");
229-
230228
var_in_node_map_ = std::make_shared<
231229
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
232230

233231
var_node_map_ = std::make_shared<
234232
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
235233

236-
GetNgFunction(engine_key, interval);
234+
GetNgFunction(ctx);
237235
}
238236

239-
void NgraphEngine::Prepare(const std::vector<int>& interval) {
237+
void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
238+
auto interval = ctx.Attr<std::vector<int>>("interval");
239+
std::string serialized_graph = ctx.Attr<std::string>("graph");
240+
241+
auto input_vars = ctx.Inputs("Xs");
242+
if (!input_vars.empty()) {
243+
feed_vars = input_vars;
244+
var_in_ = input_vars;
245+
}
246+
auto output_vars = ctx.Outputs("Ys");
247+
if (!output_vars.empty()) {
248+
var_out_ = output_vars;
249+
}
250+
251+
framework::proto::BlockDesc block_proto;
252+
if (!serialized_graph.empty()) block_proto.ParseFromString(serialized_graph);
253+
framework::BlockDesc block_desc(nullptr, &block_proto);
254+
if (!serialized_graph.empty()) {
255+
NgraphEngine::p_bdesc = &block_desc;
256+
}
257+
240258
bool has_fetch = false, is_full = false;
241259
for (auto& var : p_bdesc->AllVars()) {
242260
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
@@ -316,7 +334,15 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
316334
op_state_ = OpState::UNKNOWN;
317335
}
318336

319-
BuildNgIO(ops_desc, interval);
337+
if (var_in_.empty() && var_out_.empty()) {
338+
BuildNgIO(ops_desc, interval);
339+
}
340+
for (size_t i = 0; i < var_in_.size(); ++i) {
341+
auto var_name = var_in_[i];
342+
if (persistables_.find(var_name) == persistables_.end()) {
343+
var_in_updates_.emplace_back(i);
344+
}
345+
}
320346
}
321347

322348
void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
@@ -392,13 +418,6 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
392418
}
393419
}
394420
}
395-
396-
for (size_t i = 0; i < var_in_.size(); ++i) {
397-
auto var_name = var_in_[i];
398-
if (persistables_.find(var_name) == persistables_.end()) {
399-
var_in_updates_.emplace_back(i);
400-
}
401-
}
402421
}
403422

404423
void NgraphEngine::GetNgInputShape() {
@@ -434,7 +453,6 @@ void NgraphEngine::BuildNgNodes() {
434453
}
435454
}
436455
}
437-
438456
NgraphBridge ngb(var_node_map_);
439457
for (auto& op : fused_ops_) {
440458
ngb.BuildNgNode(op);
@@ -448,8 +466,8 @@ void NgraphEngine::RunInferShape() {
448466
}
449467
}
450468

451-
void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
452-
Prepare(interval);
469+
void NgraphEngine::BuildNgFunction(const framework::ExecutionContext& ctx) {
470+
Prepare(ctx);
453471
RunInferShape();
454472
GetNgInputShape();
455473
BuildNgNodes();
@@ -472,12 +490,13 @@ void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
472490
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
473491
}
474492

475-
void NgraphEngine::GetNgFunction(std::string engine_key,
476-
const std::vector<int>& interval) {
493+
void NgraphEngine::GetNgFunction(const framework::ExecutionContext& ctx) {
494+
auto interval = ctx.Attr<std::vector<int>>("interval");
495+
std::string engine_key = ctx.Attr<std::string>("engine_key");
477496
bool use_cache = true;
478497
if (use_cache) {
479498
this->func_cache_key_ = "";
480-
for (int i = 0; i < std::min(static_cast<int>(feed_vars.size()), 10); ++i) {
499+
for (int i = 0; i < static_cast<int>(feed_vars.size()); ++i) {
481500
auto* var = scope_.FindVar(feed_vars[i]);
482501
if (var && var->IsType<framework::LoDTensor>()) {
483502
auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
@@ -507,7 +526,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
507526
}
508527

509528
if (engine_cache.find(func_cache_key_) == engine_cache.end()) {
510-
BuildNgFunction(interval);
529+
BuildNgFunction(ctx);
511530
engine_cache[func_cache_key_].ngraph_function = this->ngraph_function_;
512531
engine_cache[func_cache_key_].persistables = this->persistables_;
513532
engine_cache[func_cache_key_].var_in_updates = this->var_in_updates_;
@@ -516,7 +535,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
516535
engine_cache[func_cache_key_].is_test = this->is_test_;
517536
}
518537
} else {
519-
BuildNgFunction(interval);
538+
BuildNgFunction(ctx);
520539
}
521540
}
522541

paddle/fluid/operators/ngraph/ngraph_engine.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class NgraphEngine {
101101
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
102102
var_node_map_;
103103
// prepare info for ngraph engine need
104-
void Prepare(const std::vector<int>& interval);
104+
void Prepare(const framework::ExecutionContext& ctx);
105105
// get ngraph engine input and output list
106106
void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
107107
const std::vector<int>& interval);
@@ -112,9 +112,9 @@ class NgraphEngine {
112112
// run paddle RuntimeInferShape to get the tensor shape
113113
void RunInferShape();
114114
// build ngraph function call
115-
void BuildNgFunction(const std::vector<int>& interval);
115+
void BuildNgFunction(const framework::ExecutionContext& ctx);
116116
// Check cache for ngraph function or otherwise build the function
117-
void GetNgFunction(std::string engine_key, const std::vector<int>& interval);
117+
void GetNgFunction(const framework::ExecutionContext& ctx);
118118
};
119119

120120
} // namespace operators

0 commit comments

Comments
 (0)