@@ -158,6 +158,8 @@ static void SubstituteNgraphOp(
158
158
ng_op_desc.SetAttr (" interval" , interval);
159
159
ng_op_desc.SetAttr (" engine_key" , engine_key);
160
160
ng_op_desc.SetAttr (" graph" , block_str);
161
+ ng_op_desc.SetInput (" Xs" , std::vector<std::string>(0 ));
162
+ ng_op_desc.SetOutput (" Ys" , std::vector<std::string>(0 ));
161
163
162
164
ops->erase (ops->begin () + interval[0 ], ops->begin () + interval[1 ]);
163
165
ops->insert (ops->begin () + interval[0 ],
@@ -223,20 +225,36 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
223
225
const platform::Place& place,
224
226
const framework::ExecutionContext& ctx)
225
227
: scope_(scope), place_(place) {
226
- std::string serialized_graph = ctx.Attr <std::string>(" graph" );
227
- auto interval = ctx.Attr <std::vector<int >>(" interval" );
228
- std::string engine_key = ctx.Attr <std::string>(" engine_key" );
229
-
230
228
var_in_node_map_ = std::make_shared<
231
229
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
232
230
233
231
var_node_map_ = std::make_shared<
234
232
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>();
235
233
236
- GetNgFunction (engine_key, interval );
234
+ GetNgFunction (ctx );
237
235
}
238
236
239
- void NgraphEngine::Prepare (const std::vector<int >& interval) {
237
+ void NgraphEngine::Prepare (const framework::ExecutionContext& ctx) {
238
+ auto interval = ctx.Attr <std::vector<int >>(" interval" );
239
+ std::string serialized_graph = ctx.Attr <std::string>(" graph" );
240
+
241
+ auto input_vars = ctx.Inputs (" Xs" );
242
+ if (!input_vars.empty ()) {
243
+ feed_vars = input_vars;
244
+ var_in_ = input_vars;
245
+ }
246
+ auto output_vars = ctx.Outputs (" Ys" );
247
+ if (!output_vars.empty ()) {
248
+ var_out_ = output_vars;
249
+ }
250
+
251
+ framework::proto::BlockDesc block_proto;
252
+ if (!serialized_graph.empty ()) block_proto.ParseFromString (serialized_graph);
253
+ framework::BlockDesc block_desc (nullptr , &block_proto);
254
+ if (!serialized_graph.empty ()) {
255
+ NgraphEngine::p_bdesc = &block_desc;
256
+ }
257
+
240
258
bool has_fetch = false , is_full = false ;
241
259
for (auto & var : p_bdesc->AllVars ()) {
242
260
if (!(var->GetType () == framework::proto::VarType::SELECTED_ROWS ||
@@ -316,7 +334,15 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
316
334
op_state_ = OpState::UNKNOWN;
317
335
}
318
336
319
- BuildNgIO (ops_desc, interval);
337
+ if (var_in_.empty () && var_out_.empty ()) {
338
+ BuildNgIO (ops_desc, interval);
339
+ }
340
+ for (size_t i = 0 ; i < var_in_.size (); ++i) {
341
+ auto var_name = var_in_[i];
342
+ if (persistables_.find (var_name) == persistables_.end ()) {
343
+ var_in_updates_.emplace_back (i);
344
+ }
345
+ }
320
346
}
321
347
322
348
void NgraphEngine::BuildNgIO (const std::vector<framework::OpDesc*>& ops_desc,
@@ -392,13 +418,6 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
392
418
}
393
419
}
394
420
}
395
-
396
- for (size_t i = 0 ; i < var_in_.size (); ++i) {
397
- auto var_name = var_in_[i];
398
- if (persistables_.find (var_name) == persistables_.end ()) {
399
- var_in_updates_.emplace_back (i);
400
- }
401
- }
402
421
}
403
422
404
423
void NgraphEngine::GetNgInputShape () {
@@ -434,7 +453,6 @@ void NgraphEngine::BuildNgNodes() {
434
453
}
435
454
}
436
455
}
437
-
438
456
NgraphBridge ngb (var_node_map_);
439
457
for (auto & op : fused_ops_) {
440
458
ngb.BuildNgNode (op);
@@ -448,8 +466,8 @@ void NgraphEngine::RunInferShape() {
448
466
}
449
467
}
450
468
451
- void NgraphEngine::BuildNgFunction (const std::vector< int >& interval ) {
452
- Prepare (interval );
469
+ void NgraphEngine::BuildNgFunction (const framework::ExecutionContext& ctx ) {
470
+ Prepare (ctx );
453
471
RunInferShape ();
454
472
GetNgInputShape ();
455
473
BuildNgNodes ();
@@ -472,12 +490,13 @@ void NgraphEngine::BuildNgFunction(const std::vector<int>& interval) {
472
490
std::make_shared<ngraph::Function>(func_outputs, func_inputs);
473
491
}
474
492
475
- void NgraphEngine::GetNgFunction (std::string engine_key,
476
- const std::vector<int >& interval) {
493
+ void NgraphEngine::GetNgFunction (const framework::ExecutionContext& ctx) {
494
+ auto interval = ctx.Attr <std::vector<int >>(" interval" );
495
+ std::string engine_key = ctx.Attr <std::string>(" engine_key" );
477
496
bool use_cache = true ;
478
497
if (use_cache) {
479
498
this ->func_cache_key_ = " " ;
480
- for (int i = 0 ; i < std::min ( static_cast <int >(feed_vars.size ()), 10 ); ++i) {
499
+ for (int i = 0 ; i < static_cast <int >(feed_vars.size ()); ++i) {
481
500
auto * var = scope_.FindVar (feed_vars[i]);
482
501
if (var && var->IsType <framework::LoDTensor>()) {
483
502
auto * tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar (*var);
@@ -507,7 +526,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
507
526
}
508
527
509
528
if (engine_cache.find (func_cache_key_) == engine_cache.end ()) {
510
- BuildNgFunction (interval );
529
+ BuildNgFunction (ctx );
511
530
engine_cache[func_cache_key_].ngraph_function = this ->ngraph_function_ ;
512
531
engine_cache[func_cache_key_].persistables = this ->persistables_ ;
513
532
engine_cache[func_cache_key_].var_in_updates = this ->var_in_updates_ ;
@@ -516,7 +535,7 @@ void NgraphEngine::GetNgFunction(std::string engine_key,
516
535
engine_cache[func_cache_key_].is_test = this ->is_test_ ;
517
536
}
518
537
} else {
519
- BuildNgFunction (interval );
538
+ BuildNgFunction (ctx );
520
539
}
521
540
}
522
541
0 commit comments