Skip to content

[Inference] config.delete_pass api can delete inplace_pass and fix a bug in replace_fetch_with_shadow_output_pass #68837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,14 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
sub_scope_);
basic_pass_pm.AddPass(std::move(dead_code_elimination_pass));
}
basic_pass_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
auto replace_fetch_with_shadow_output_pass =
::pir::CreateReplaceFetchWithShadowOutputPass();
if (std::find(config_.deleted_passes_.begin(),
config_.deleted_passes_.end(),
replace_fetch_with_shadow_output_pass->name()) ==
config_.deleted_passes_.end()) {
basic_pass_pm.AddPass(std::move(replace_fetch_with_shadow_output_pass));
}
if (!config_.glog_info_disabled()) {
basic_pass_pm.EnablePrintStatistics();
}
Expand All @@ -1048,10 +1055,20 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {

::pir::PassManager lowered_pm(::pir::IrContext::Instance(), 3);
auto remove_shadow_feed_pass = ::pir::CreateRemoveShadowFeedPass();
remove_shadow_feed_pass->Set("used_for_inference", new bool(true));
lowered_pm.AddPass(std::move(remove_shadow_feed_pass));
if (std::find(config_.deleted_passes_.begin(),
config_.deleted_passes_.end(),
remove_shadow_feed_pass->name()) ==
config_.deleted_passes_.end()) {
remove_shadow_feed_pass->Set("used_for_inference", new bool(true));
lowered_pm.AddPass(std::move(remove_shadow_feed_pass));
}
if (FLAGS_pir_apply_inplace_pass) {
lowered_pm.AddPass(::pir::CreateInplacePass());
auto inplace_pass = ::pir::CreateInplacePass();
if (std::find(config_.deleted_passes_.begin(),
config_.deleted_passes_.end(),
inplace_pass->name()) == config_.deleted_passes_.end()) {
lowered_pm.AddPass(std::move(inplace_pass));
}
}
if (!config_.glog_info_disabled()) {
lowered_pm.EnablePrintStatistics();
Expand Down Expand Up @@ -1080,6 +1097,8 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) {
std::string fetch_name =
op->attribute("name").dyn_cast<pir::StrAttribute>().AsString();
idx2fetches_[idx] = fetch_name;
fetch_name2shapes_[fetch_name] =
pir::GetShapeFromValue(op->operand_source(0));
}
} else if (op->isa<paddle::dialect::DataOp>() ||
op->isa<paddle::dialect::FeedOp>()) {
Expand All @@ -1092,6 +1111,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) {
feed_names_[data_name] = feed_idx;
feed_idx++;
pir_feeds_.emplace_back(op);
feed_name2shapes_[data_name] = pir::GetShapeFromValue(op->result(0));
}

if (op->isa<::pir::ParameterOp>()) {
Expand Down Expand Up @@ -2535,6 +2555,9 @@ std::vector<std::string> AnalysisPredictor::GetInputNames() {

std::map<std::string, std::vector<int64_t>>
AnalysisPredictor::GetInputTensorShape() {
if (load_pir_model_) {
return feed_name2shapes_;
}
std::map<std::string, std::vector<int64_t>> input_shapes;
std::vector<std::string> names = GetInputNames();
for (std::string const &name : names) {
Expand Down Expand Up @@ -2594,6 +2617,9 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() {

std::map<std::string, std::vector<int64_t>>
AnalysisPredictor::GetOutputTensorShape() {
if (load_pir_model_) {
return fetch_name2shapes_;
}
std::map<std::string, std::vector<int64_t>> output_shapes;
std::vector<std::string> names = GetOutputNames();
for (std::string const &name : names) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,9 +578,11 @@ class AnalysisPredictor : public PaddlePredictor {
std::map<std::string, size_t> feed_names_;
// Sorted according to the idx.
std::map<size_t, std::string> idx2feeds_;
std::map<std::string, std::vector<int64_t>> feed_name2shapes_;
std::vector<framework::OpDesc *> fetches_;
std::vector<pir::Operation *> pir_fetches_;
std::map<size_t, std::string> idx2fetches_;
std::map<std::string, std::vector<int64_t>> fetch_name2shapes_;

phi::DataType model_precision_{phi::DataType::FLOAT32};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
#include "paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/utils/general_functions.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"

Expand All @@ -28,9 +30,20 @@ class ReplaceFetchWithShadowOutputPattern
bool MatchAndRewrite(
paddle::dialect::FetchOp op,
pir::PatternRewriter& rewriter) const override { // NOLINT
rewriter.Build<pir::ShadowOutputOp>(
op->operand_source(0),
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
// For pd_op.data -> value -> pd_op.fetch, we insert pd_op.scale before
// pd_op.fetch to avoid errors like [what(): (NotFound) Variable 'xxx' is
// not found in scope.]
if (pir::GetDefiningOpForInput(op, 0)->HasAttribute("name")) {
auto scale_op = rewriter.Build<paddle::dialect::ScaleOp>(
op->operand_source(0), 1.0, 0.0, true);
rewriter.Build<pir::ShadowOutputOp>(
scale_op->result(0),
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
} else {
rewriter.Build<pir::ShadowOutputOp>(
op->operand_source(0),
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
}
rewriter.EraseOp(op);
return true;
}
Expand Down