@@ -14,6 +14,11 @@ limitations under the License. */
14
14
15
15
#include " paddle/fluid/operators/controlflow/conditional_block_op.h"
16
16
17
+ #ifdef PADDLE_WITH_MKLDNN
18
+ #include " paddle/fluid/platform/mkldnn_helper.h"
19
+ #endif
20
+
21
+ DECLARE_bool (use_mkldnn);
17
22
namespace paddle {
18
23
namespace framework {
19
24
class OpDesc ;
@@ -73,14 +78,29 @@ class ConditionalBlockInferOp : public ConditionalOp {
73
78
scopes->front () = &scope.NewScope ();
74
79
auto &cur_scope = *scopes->front ();
75
80
76
- framework::Executor exec (dev_place);
77
81
auto *block = Attr<framework::BlockDesc *>(" sub_block" );
78
82
VLOG (3 ) << " Conditional block.idx = " << block->ID ()
79
83
<< " , scope = " << &cur_scope;
80
- exec.Run (*block->Program (), &cur_scope, block->ID (), false );
84
+
85
+ if (!exec || !platform::is_same_place (exec->GetPlace (), dev_place)) {
86
+ auto &pdesc = *block->Program ();
87
+ exec.reset (new framework::Executor (dev_place));
88
+ if (FLAGS_use_mkldnn) exec->EnableMKLDNN (pdesc);
89
+ ctx = exec->Prepare (
90
+ pdesc, block->ID (), std::vector<std::string>(), false );
91
+ #ifdef PADDLE_WITH_MKLDNN
92
+ platform::AttachPointerHashToMKLDNNKey (exec.get (), dev_place);
93
+ platform::RegisterModelLayout (ctx->ops_ , dev_place);
94
+ #endif
95
+ }
96
+ exec->RunPreparedContext (ctx.get (), &cur_scope, false , true , false );
81
97
scope.DeleteScope (scopes->front ());
82
98
}
83
99
}
100
+
101
+ private:
102
+ mutable std::shared_ptr<framework::Executor> exec{nullptr };
103
+ mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx{nullptr };
84
104
};
85
105
86
106
} // namespace operators
0 commit comments