Skip to content

Commit ff158dd

Browse files
authored
[NewExe] cache infer_meta for program_interpreter (#58575)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * fix * refine
1 parent a5b9a74 commit ff158dd

File tree

9 files changed

+105
-13
lines changed

9 files changed

+105
-13
lines changed

paddle/fluid/framework/details/op_registry.h

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,38 @@ struct OpInfoFiller<T, kVarTypeInference> {
318318
}
319319
};
320320

321+
template <typename T, typename = void>
322+
struct InferMetaTrait {
323+
static void call(const char* op_type UNUSED, OpInfo* info) {
324+
info->infer_shape_ = [](InferShapeContext* ctx) {
325+
T inference;
326+
inference(ctx);
327+
};
328+
}
329+
};
330+
321331
template <typename T>
322-
struct OpInfoFiller<T, kShapeInference> {
323-
void operator()(const char* op_type UNUSED, OpInfo* info) const {
324-
// Note: if fill InferShapeFN by this Filler, the infershape here
325-
// will overwrite the op->InferShape func registered in kOperator Filler
332+
struct InferMetaTrait<T,
333+
decltype(std::declval<T>().infer_meta_(
334+
std::declval<phi::InferMetaContext*>()))> {
335+
static void call(const char* op_type UNUSED, OpInfo* info) {
326336
info->infer_shape_ = [](InferShapeContext* ctx) {
327337
T inference;
328338
inference(ctx);
329339
};
340+
info->infer_meta_ = [](phi::InferMetaContext* ctx) {
341+
T inference;
342+
inference.infer_meta_(ctx);
343+
};
344+
}
345+
};
346+
347+
template <typename T>
348+
struct OpInfoFiller<T, kShapeInference> {
349+
void operator()(const char* op_type UNUSED, OpInfo* info) const {
350+
// Note: if fill InferShapeFN by this Filler, the infershape here
351+
// will overwrite the op->InferShape func registered in kOperator Filler
352+
InferMetaTrait<T>::call(op_type, info);
330353
}
331354
};
332355

paddle/fluid/framework/infershape_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
151151
paddle::framework::BuildInferMetaContext(ctx, #op_type); \
152152
fn(&infer_meta_context); \
153153
} \
154+
void infer_meta_(phi::InferMetaContext* ctx) const { fn(ctx); } \
154155
}
155156

156157
} // namespace framework

paddle/fluid/framework/new_executor/new_executor_defs.cc

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <unordered_map>
2020
#include <vector>
2121

22+
#include "paddle/fluid/framework/infershape_utils.h"
2223
#include "paddle/fluid/platform/profiler/event_tracing.h"
2324

2425
namespace paddle {
@@ -237,7 +238,8 @@ const std::vector<size_t>& Instruction::GCCheckVars() const {
237238
}
238239

239240
void Instruction::ResetContext(const VariableValueMap& in_vars,
240-
const VariableValueMap& out_vars) {
241+
const VariableValueMap& out_vars,
242+
const std::string& op_name) {
241243
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
242244
infershape_ctx_.reset(
243245
new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get()));
@@ -246,16 +248,37 @@ void Instruction::ResetContext(const VariableValueMap& in_vars,
246248
static framework::Scope scope_;
247249
execution_ctx_.reset(
248250
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
251+
252+
auto op_with_kernel =
253+
dynamic_cast<const framework::OperatorWithKernel*>(OpBase());
254+
if (op_with_kernel != nullptr && op_with_kernel->Info().infer_meta_) {
255+
if (infershape_ctx_->HasRuntimeAttributes() == false) {
256+
compat_infermeta_ctx_ = paddle::framework::BuildInferMetaContext(
257+
infershape_ctx_.get(), op_name);
258+
can_use_infermeta_ctx_ = true;
259+
}
260+
}
249261
}
250262

251263
void Instruction::ResetContextWithScope(const VariableValueMap& in_vars,
252264
const VariableValueMap& out_vars,
253-
const framework::Scope& scope) {
265+
const framework::Scope& scope,
266+
const std::string& op_name) {
254267
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
255268
infershape_ctx_.reset(
256269
new RuntimeInferShapeContext(*OpBase(), *runtime_ctx_.get()));
257270
execution_ctx_.reset(
258271
new ExecutionContext(*OpBase(), scope, dev_ctx_, *runtime_ctx_.get()));
272+
273+
auto op_with_kernel =
274+
dynamic_cast<const framework::OperatorWithKernel*>(OpBase());
275+
if (op_with_kernel != nullptr && op_with_kernel->Info().infer_meta_) {
276+
if (infershape_ctx_->HasRuntimeAttributes() == false) {
277+
compat_infermeta_ctx_ = paddle::framework::BuildInferMetaContext(
278+
infershape_ctx_.get(), op_name);
279+
can_use_infermeta_ctx_ = true;
280+
}
281+
}
259282
}
260283

261284
std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
@@ -267,6 +290,10 @@ std::shared_ptr<RuntimeInferShapeContext> Instruction::InnerInferShapeContext()
267290
return infershape_ctx_;
268291
}
269292

293+
const phi::InferMetaContext* Instruction::InnerCompatInferMetaContext() const {
294+
return &compat_infermeta_ctx_;
295+
}
296+
270297
std::shared_ptr<ExecutionContext> Instruction::InnerExecutionContext() const {
271298
return execution_ctx_;
272299
}

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
#include <unordered_map>
1919
#include <vector>
2020

21+
#include "paddle/fluid/framework/infershape_utils.h"
2122
#include "paddle/fluid/framework/operator.h"
2223
#include "paddle/fluid/framework/variable_helper.h"
2324
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
2425
#include "paddle/fluid/platform/device_event_base.h"
2526
#include "paddle/fluid/platform/event.h"
27+
#include "paddle/phi/core/infermeta_utils.h"
2628
#include "paddle/phi/core/utils/rw_lock.h"
2729

2830
#define SCOPE_VARS_READER_LOCK AutoRDLock auto_lock(&vars_lock_);
@@ -262,16 +264,20 @@ class Instruction {
262264
const std::vector<size_t>& GCCheckVars() const;
263265

264266
void ResetContext(const VariableValueMap& in_vars,
265-
const VariableValueMap& out_vars);
267+
const VariableValueMap& out_vars,
268+
const std::string& op_name);
266269

267270
void ResetContextWithScope(const VariableValueMap& in_vars,
268271
const VariableValueMap& out_vars,
269-
const framework::Scope& scope);
272+
const framework::Scope& scope,
273+
const std::string& op_name);
270274

271275
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
272276

273277
std::shared_ptr<RuntimeInferShapeContext> InnerInferShapeContext() const;
274278

279+
const phi::InferMetaContext* InnerCompatInferMetaContext() const;
280+
275281
std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
276282

277283
const platform::DeviceContext& DeviceContext() const;
@@ -290,6 +296,8 @@ class Instruction {
290296

291297
const OpFuncNode* OpFunc() const { return &op_func_node_; }
292298

299+
bool can_use_infermeta_ctx_ = false;
300+
293301
private:
294302
bool is_artificial_; // Instruction is artificial means that it is only used
295303
// to assist scheduling and no need to be executed.
@@ -307,6 +315,7 @@ class Instruction {
307315

308316
std::shared_ptr<RuntimeContext> runtime_ctx_;
309317
std::shared_ptr<RuntimeInferShapeContext> infershape_ctx_;
318+
paddle::framework::CompatInferMetaContext compat_infermeta_ctx_;
310319
std::shared_ptr<ExecutionContext> execution_ctx_;
311320

312321
std::vector<size_t> gc_check_vars_;

paddle/fluid/framework/new_executor/program_interpreter.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,10 @@ void ProgramInterpreter::BuildAndCacheInstructionCtx(Instruction* instr_node) {
406406
// in kernel
407407
Scope* local_scope = HasLocalScope() ? var_scope_.GetMutableLocalScope()
408408
: var_scope_.GetMutableScope();
409-
instr_node->ResetContextWithScope(ins_map, outs_map, *local_scope);
409+
instr_node->ResetContextWithScope(
410+
ins_map, outs_map, *local_scope, instr_node->OpBase()->Type());
410411
} else {
411-
instr_node->ResetContext(ins_map, outs_map);
412+
instr_node->ResetContext(ins_map, outs_map, instr_node->OpBase()->Type());
412413
}
413414
}
414415

@@ -881,10 +882,14 @@ void ProgramInterpreter::RunOperator(const Instruction& instr_node) {
881882
// see OperatorWithKernel::RunImpl in operator.cc for why
882883
if (!(op_with_kernel->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
883884
op_with_kernel->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
884-
op_with_kernel->Info().infer_shape_(
885-
instr_node.InnerInferShapeContext().get());
885+
if (instr_node.can_use_infermeta_ctx_) {
886+
op_with_kernel->Info().infer_meta_(const_cast<phi::InferMetaContext*>(
887+
instr_node.InnerCompatInferMetaContext()));
888+
} else {
889+
op_with_kernel->Info().infer_shape_(
890+
instr_node.InnerInferShapeContext().get());
891+
}
886892
}
887-
infershape_event.End();
888893
platform::RecordOpInfoSupplement(op->Type(),
889894
op->Attrs(),
890895
*(instr_node.InnerInferShapeContext()),

paddle/fluid/framework/op_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class OpInfo {
4848
OpAttrChecker* checker_{nullptr};
4949
InferVarTypeFN infer_var_type_;
5050
InferShapeFN infer_shape_;
51+
InferMetaFN infer_meta_;
5152
InferInplaceOpFN infer_inplace_;
5253
InferNoNeedBufferVarsFN infer_no_need_buffer_vars_;
5354
DygraphGradOpMakerFN dygraph_grad_op_maker_;

paddle/fluid/framework/operator.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,28 @@ RuntimeInferShapeContext::GetPhiDefaultKernelSignature() const {
607607

608608
void RuntimeInferShapeContext::SetSkipLoD(bool skip) { can_skip_lod_ = skip; }
609609

610+
bool RuntimeInferShapeContext::HasRuntimeAttributes() const {
611+
bool is_runtime = false;
612+
if (phi::DefaultKernelSignatureMap::Instance().Has(op_.Type())) {
613+
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
614+
GetPhiDefaultKernelSignature()->name);
615+
if (!phi_kernels.empty()) {
616+
const auto& args_def = phi_kernels.cbegin()->second.args_def();
617+
const auto& attr_defs = args_def.attribute_defs();
618+
for (size_t i = 0; i < attr_defs.size(); ++i) {
619+
if (attr_defs[i].type_index == phi::AttributeType::SCALAR ||
620+
attr_defs[i].type_index == phi::AttributeType::INT_ARRAY) {
621+
is_runtime = true;
622+
break;
623+
}
624+
}
625+
}
626+
} else {
627+
is_runtime = true;
628+
}
629+
return is_runtime;
630+
}
631+
610632
std::vector<LoD> RuntimeInferShapeContext::GetOutputsLod(
611633
const std::string& out) const {
612634
auto out_it = ctx_.outputs.find(out);

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
233233

234234
std::vector<DDim> GetOutputsDim(const std::string& name) const;
235235

236+
bool HasRuntimeAttributes() const;
237+
236238
protected:
237239
DDim GetDim(Variable* var) const;
238240

paddle/fluid/framework/type_defs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License. */
2525
#include "paddle/fluid/imperative/type_defs.h"
2626

2727
#include "paddle/phi/common/scalar.h"
28+
#include "paddle/phi/core/infermeta_utils.h"
2829
#include "paddle/pir/core/block.h"
2930
#include "paddle/pir/core/value.h"
3031
#include "paddle/utils/blank.h"
@@ -102,6 +103,7 @@ using InferVarTypeFN =
102103
std::function<void(framework::InferVarTypeContext* /*context*/)>;
103104

104105
using InferShapeFN = std::function<void(InferShapeContext*)>;
106+
using InferMetaFN = std::function<void(phi::InferMetaContext*)>;
105107

106108
using InplacePair = std::unordered_map<std::string, std::string>;
107109
using InferInplaceOpFN = std::function<InplacePair(bool /*use_cuda*/)>;

0 commit comments

Comments
 (0)