From 647f638fe8120cd5adc574b60bd790d981483cdf Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 21 Oct 2024 03:17:53 +0000 Subject: [PATCH 1/4] config.delete_pass api can delete inplace_pass --- paddle/fluid/inference/api/analysis_predictor.cc | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 92062fe4ff8100..7cf08450474be4 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1048,10 +1048,20 @@ void AnalysisPredictor::OptimizeInferencePirProgram() { ::pir::PassManager lowered_pm(::pir::IrContext::Instance(), 3); auto remove_shadow_feed_pass = ::pir::CreateRemoveShadowFeedPass(); - remove_shadow_feed_pass->Set("used_for_inference", new bool(true)); - lowered_pm.AddPass(std::move(remove_shadow_feed_pass)); + if (std::find(config_.deleted_passes_.begin(), + config_.deleted_passes_.end(), + remove_shadow_feed_pass->name()) == + config_.deleted_passes_.end()) { + remove_shadow_feed_pass->Set("used_for_inference", new bool(true)); + lowered_pm.AddPass(std::move(remove_shadow_feed_pass)); + } if (FLAGS_pir_apply_inplace_pass) { - lowered_pm.AddPass(::pir::CreateInplacePass()); + auto inplace_pass = ::pir::CreateInplacePass(); + if (std::find(config_.deleted_passes_.begin(), + config_.deleted_passes_.end(), + inplace_pass->name()) == config_.deleted_passes_.end()) { + lowered_pm.AddPass(std::move(inplace_pass)); + } } if (!config_.glog_info_disabled()) { lowered_pm.EnablePrintStatistics(); From 69360ae721b4ebfc9ad84f92e5cde870108396af Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 21 Oct 2024 03:40:46 +0000 Subject: [PATCH 2/4] fix replace_fetch_with_shadow_output_pass --- .../general/replace_fetch_with_shadow_output_pass.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc index 9bb8e539c2def1..dd1d7173ee9a5a 100644 --- a/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc +++ b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -28,6 +29,10 @@ class ReplaceFetchWithShadowOutputPattern bool MatchAndRewrite( paddle::dialect::FetchOp op, pir::PatternRewriter& rewriter) const override { // NOLINT + if (pir::GetDefiningOpForInput(op, 0)->HasAttribute("name")) { + // DataOp/FeedOp + return false; + } rewriter.Build( op->operand_source(0), op->attributes().at("name").dyn_cast().AsString()); From 6d9c25478a7e52ea9b72ef0647cf8e78586ba775 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 21 Oct 2024 11:24:42 +0000 Subject: [PATCH 3/4] fix replace_fetch_with_shadow_output_pass --- .../fluid/inference/api/analysis_predictor.cc | 17 ++++++++++++++++- .../fluid/inference/api/analysis_predictor.h | 2 ++ .../pir/transforms/general/inplace_pass.cc | 4 ++++ .../replace_fetch_with_shadow_output_pass.cc | 18 +++++++++++++----- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7cf08450474be4..323ef2c69ff973 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1031,7 +1031,14 @@ void AnalysisPredictor::OptimizeInferencePirProgram() { sub_scope_); basic_pass_pm.AddPass(std::move(dead_code_elimination_pass)); } - basic_pass_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); + auto replace_fetch_with_shadow_output_pass = + ::pir::CreateReplaceFetchWithShadowOutputPass(); + if (std::find(config_.deleted_passes_.begin(), + config_.deleted_passes_.end(), + replace_fetch_with_shadow_output_pass->name()) == + config_.deleted_passes_.end()) { + basic_pass_pm.AddPass(std::move(replace_fetch_with_shadow_output_pass)); + } if (!config_.glog_info_disabled()) { basic_pass_pm.EnablePrintStatistics(); } @@ -1090,6 +1097,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) { std::string fetch_name = op->attribute("name").dyn_cast().AsString(); idx2fetches_[idx] = fetch_name; + fetch_name2shapes_[fetch_name] = pir::GetShapeFromValue(op->operand_source(0)); } } else if (op->isa() || op->isa()) { @@ -1102,6 +1110,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) { feed_names_[data_name] = feed_idx; feed_idx++; pir_feeds_.emplace_back(op); + feed_name2shapes_[data_name] = pir::GetShapeFromValue(op->result(0)); } if (op->isa<::pir::ParameterOp>()) { @@ -2545,6 +2554,9 @@ std::vector AnalysisPredictor::GetInputNames() { std::map> AnalysisPredictor::GetInputTensorShape() { + if (load_pir_model_) { + return feed_name2shapes_; + } std::map> input_shapes; std::vector names = GetInputNames(); for (std::string const &name : names) { @@ -2604,6 +2616,9 @@ std::vector AnalysisPredictor::GetOutputNames() { std::map> AnalysisPredictor::GetOutputTensorShape() { + if (load_pir_model_) { + return fetch_name2shapes_; + } std::map> output_shapes; std::vector names = GetOutputNames(); for (std::string const &name : names) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 1d99ddfef4ca96..9c81ab85bfdb00 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -578,9 +578,11 @@ class AnalysisPredictor : public PaddlePredictor { std::map feed_names_; // Sorted according to the idx. std::map idx2feeds_; + std::map> feed_name2shapes_; std::vector fetches_; std::vector pir_fetches_; std::map idx2fetches_; + std::map> fetch_name2shapes_; phi::DataType model_precision_{phi::DataType::FLOAT32}; diff --git a/paddle/fluid/pir/transforms/general/inplace_pass.cc b/paddle/fluid/pir/transforms/general/inplace_pass.cc index 23506b1d456914..8baed9f9f8ddea 100644 --- a/paddle/fluid/pir/transforms/general/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/general/inplace_pass.cc @@ -88,6 +88,10 @@ bool CanDoInplace(const std::unordered_set& eager_dels, return false; } + // for (auto it = output.use_begin(); it != output.use_end(); ++it) { + // if(it->owner()->isa() || ) + // } + if (input.type().isa() && output.type().isa()) { auto input_alloc_tensor_type = input.type().dyn_cast(); auto output_alloc_tensor_type = output.type().dyn_cast(); diff --git a/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc index dd1d7173ee9a5a..0ce10eddef9d94 100644 --- a/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc +++ b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/core/ir_context.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -29,13 +30,20 @@ class ReplaceFetchWithShadowOutputPattern bool MatchAndRewrite( paddle::dialect::FetchOp op, pir::PatternRewriter& rewriter) const override { // NOLINT + // for pd_op.data -> value -> pd_op.fetch, we insert pd_op.scale before + // pd_op.fetch to solve error likes [what(): (NotFound) Variable 'xxx' is + // not found in scope.] if (pir::GetDefiningOpForInput(op, 0)->HasAttribute("name")) { - // DataOp/FeedOp - return false; + auto scale_op = rewriter.Build( + op->operand_source(0), 1.0, 0.0, true); + rewriter.Build( + scale_op->result(0), + op->attributes().at("name").dyn_cast().AsString()); + } else { + rewriter.Build( + op->operand_source(0), + op->attributes().at("name").dyn_cast().AsString()); } - rewriter.Build( - op->operand_source(0), - op->attributes().at("name").dyn_cast().AsString()); rewriter.EraseOp(op); return true; } From f28eafc82086f0078419315c6fe693e09f7f9c32 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 21 Oct 2024 11:39:51 +0000 Subject: [PATCH 4/4] fix codestyle --- paddle/fluid/inference/api/analysis_predictor.cc | 3 ++- paddle/fluid/pir/transforms/general/inplace_pass.cc | 4 ---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 323ef2c69ff973..3208e2558b1718 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1097,7 +1097,8 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) { std::string fetch_name = op->attribute("name").dyn_cast().AsString(); idx2fetches_[idx] = fetch_name; - fetch_name2shapes_[fetch_name] = pir::GetShapeFromValue(op->operand_source(0)); + fetch_name2shapes_[fetch_name] = + pir::GetShapeFromValue(op->operand_source(0)); } } else if (op->isa() || op->isa()) { diff --git a/paddle/fluid/pir/transforms/general/inplace_pass.cc b/paddle/fluid/pir/transforms/general/inplace_pass.cc index 8baed9f9f8ddea..23506b1d456914 100644 --- a/paddle/fluid/pir/transforms/general/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/general/inplace_pass.cc @@ -88,10 +88,6 @@ bool CanDoInplace(const std::unordered_set& eager_dels, return false; } - // for (auto it = output.use_begin(); it != output.use_end(); ++it) { - // if(it->owner()->isa() || ) - // } - if (input.type().isa() && output.type().isa()) { auto input_alloc_tensor_type = input.type().dyn_cast(); auto output_alloc_tensor_type = output.type().dyn_cast();