diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 92062fe4ff8100..3208e2558b1718 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1031,7 +1031,14 @@ void AnalysisPredictor::OptimizeInferencePirProgram() { sub_scope_); basic_pass_pm.AddPass(std::move(dead_code_elimination_pass)); } - basic_pass_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); + auto replace_fetch_with_shadow_output_pass = + ::pir::CreateReplaceFetchWithShadowOutputPass(); + if (std::find(config_.deleted_passes_.begin(), + config_.deleted_passes_.end(), + replace_fetch_with_shadow_output_pass->name()) == + config_.deleted_passes_.end()) { + basic_pass_pm.AddPass(std::move(replace_fetch_with_shadow_output_pass)); + } if (!config_.glog_info_disabled()) { basic_pass_pm.EnablePrintStatistics(); } @@ -1048,10 +1055,20 @@ void AnalysisPredictor::OptimizeInferencePirProgram() { ::pir::PassManager lowered_pm(::pir::IrContext::Instance(), 3); auto remove_shadow_feed_pass = ::pir::CreateRemoveShadowFeedPass(); - remove_shadow_feed_pass->Set("used_for_inference", new bool(true)); - lowered_pm.AddPass(std::move(remove_shadow_feed_pass)); + if (std::find(config_.deleted_passes_.begin(), + config_.deleted_passes_.end(), + remove_shadow_feed_pass->name()) == + config_.deleted_passes_.end()) { + remove_shadow_feed_pass->Set("used_for_inference", new bool(true)); + lowered_pm.AddPass(std::move(remove_shadow_feed_pass)); + } if (FLAGS_pir_apply_inplace_pass) { - lowered_pm.AddPass(::pir::CreateInplacePass()); + auto inplace_pass = ::pir::CreateInplacePass(); + if (std::find(config_.deleted_passes_.begin(), + config_.deleted_passes_.end(), + inplace_pass->name()) == config_.deleted_passes_.end()) { + lowered_pm.AddPass(std::move(inplace_pass)); + } } if (!config_.glog_info_disabled()) { lowered_pm.EnablePrintStatistics(); @@ -1080,6 +1097,8 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) { std::string fetch_name = op->attribute("name").dyn_cast().AsString(); idx2fetches_[idx] = fetch_name; + fetch_name2shapes_[fetch_name] = + pir::GetShapeFromValue(op->operand_source(0)); } } else if (op->isa() || op->isa()) { @@ -1092,6 +1111,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) { feed_names_[data_name] = feed_idx; feed_idx++; pir_feeds_.emplace_back(op); + feed_name2shapes_[data_name] = pir::GetShapeFromValue(op->result(0)); } if (op->isa<::pir::ParameterOp>()) { @@ -2535,6 +2555,9 @@ std::vector AnalysisPredictor::GetInputNames() { std::map> AnalysisPredictor::GetInputTensorShape() { + if (load_pir_model_) { + return feed_name2shapes_; + } std::map> input_shapes; std::vector names = GetInputNames(); for (std::string const &name : names) { @@ -2594,6 +2617,9 @@ std::vector AnalysisPredictor::GetOutputNames() { std::map> AnalysisPredictor::GetOutputTensorShape() { + if (load_pir_model_) { + return fetch_name2shapes_; + } std::map> output_shapes; std::vector names = GetOutputNames(); for (std::string const &name : names) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 1d99ddfef4ca96..9c81ab85bfdb00 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -578,9 +578,11 @@ class AnalysisPredictor : public PaddlePredictor { std::map feed_names_; // Sorted according to the idx. std::map idx2feeds_; + std::map> feed_name2shapes_; std::vector fetches_; std::vector pir_fetches_; std::map idx2fetches_; + std::map> fetch_name2shapes_; phi::DataType model_precision_{phi::DataType::FLOAT32}; diff --git a/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc index 9bb8e539c2def1..0ce10eddef9d94 100644 --- a/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc +++ b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc @@ -15,7 +15,9 @@ #include "paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/core/ir_context.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -28,9 +30,20 @@ class ReplaceFetchWithShadowOutputPattern bool MatchAndRewrite( paddle::dialect::FetchOp op, pir::PatternRewriter& rewriter) const override { // NOLINT - rewriter.Build( - op->operand_source(0), - op->attributes().at("name").dyn_cast().AsString()); + // for pd_op.data -> value -> pd_op.fetch, we insert pd_op.scale before + // pd_op.fetch to solve error likes [what(): (NotFound) Variable 'xxx' is + // not found in scope.] + if (pir::GetDefiningOpForInput(op, 0)->HasAttribute("name")) { + auto scale_op = rewriter.Build( + op->operand_source(0), 1.0, 0.0, true); + rewriter.Build( + scale_op->result(0), + op->attributes().at("name").dyn_cast().AsString()); + } else { + rewriter.Build( + op->operand_source(0), + op->attributes().at("name").dyn_cast().AsString()); + } rewriter.EraseOp(op); return true; }