Skip to content

Commit 42e312a

Browse files
authored
Construct exec and ctx only once in cond op to speed up (#47009)
* cond infer apply exec seprate * fix bugs
1 parent 1cc482b commit 42e312a

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

paddle/fluid/operators/controlflow/conditional_block_infer_op.cc

+22-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/controlflow/conditional_block_op.h"
1616

17+
#ifdef PADDLE_WITH_MKLDNN
18+
#include "paddle/fluid/platform/mkldnn_helper.h"
19+
#endif
20+
21+
DECLARE_bool(use_mkldnn);
1722
namespace paddle {
1823
namespace framework {
1924
class OpDesc;
@@ -73,14 +78,29 @@ class ConditionalBlockInferOp : public ConditionalOp {
7378
scopes->front() = &scope.NewScope();
7479
auto &cur_scope = *scopes->front();
7580

76-
framework::Executor exec(dev_place);
7781
auto *block = Attr<framework::BlockDesc *>("sub_block");
7882
VLOG(3) << "Conditional block.idx = " << block->ID()
7983
<< ", scope = " << &cur_scope;
80-
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
84+
85+
if (!exec || !platform::is_same_place(exec->GetPlace(), dev_place)) {
86+
auto &pdesc = *block->Program();
87+
exec.reset(new framework::Executor(dev_place));
88+
if (FLAGS_use_mkldnn) exec->EnableMKLDNN(pdesc);
89+
ctx = exec->Prepare(
90+
pdesc, block->ID(), std::vector<std::string>(), false);
91+
#ifdef PADDLE_WITH_MKLDNN
92+
platform::AttachPointerHashToMKLDNNKey(exec.get(), dev_place);
93+
platform::RegisterModelLayout(ctx->ops_, dev_place);
94+
#endif
95+
}
96+
exec->RunPreparedContext(ctx.get(), &cur_scope, false, true, false);
8197
scope.DeleteScope(scopes->front());
8298
}
8399
}
100+
101+
private:
102+
mutable std::shared_ptr<framework::Executor> exec{nullptr};
103+
mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx{nullptr};
84104
};
85105

86106
} // namespace operators

0 commit comments

Comments
 (0)