Skip to content

Commit 6d9c254

Browse files
committed
fix replace_fetch_with_shadow_output_pass
1 parent 69360ae commit 6d9c254

File tree

4 files changed

+35
-6
lines changed

4 files changed

+35
-6
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

+16-1
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,14 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
10311031
sub_scope_);
10321032
basic_pass_pm.AddPass(std::move(dead_code_elimination_pass));
10331033
}
1034-
basic_pass_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
1034+
auto replace_fetch_with_shadow_output_pass =
1035+
::pir::CreateReplaceFetchWithShadowOutputPass();
1036+
if (std::find(config_.deleted_passes_.begin(),
1037+
config_.deleted_passes_.end(),
1038+
replace_fetch_with_shadow_output_pass->name()) ==
1039+
config_.deleted_passes_.end()) {
1040+
basic_pass_pm.AddPass(std::move(replace_fetch_with_shadow_output_pass));
1041+
}
10351042
if (!config_.glog_info_disabled()) {
10361043
basic_pass_pm.EnablePrintStatistics();
10371044
}
@@ -1090,6 +1097,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) {
10901097
std::string fetch_name =
10911098
op->attribute("name").dyn_cast<pir::StrAttribute>().AsString();
10921099
idx2fetches_[idx] = fetch_name;
1100+
fetch_name2shapes_[fetch_name] = pir::GetShapeFromValue(op->operand_source(0));
10931101
}
10941102
} else if (op->isa<paddle::dialect::DataOp>() ||
10951103
op->isa<paddle::dialect::FeedOp>()) {
@@ -1102,6 +1110,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) {
11021110
feed_names_[data_name] = feed_idx;
11031111
feed_idx++;
11041112
pir_feeds_.emplace_back(op);
1113+
feed_name2shapes_[data_name] = pir::GetShapeFromValue(op->result(0));
11051114
}
11061115

11071116
if (op->isa<::pir::ParameterOp>()) {
@@ -2545,6 +2554,9 @@ std::vector<std::string> AnalysisPredictor::GetInputNames() {
25452554

25462555
std::map<std::string, std::vector<int64_t>>
25472556
AnalysisPredictor::GetInputTensorShape() {
2557+
if (load_pir_model_) {
2558+
return feed_name2shapes_;
2559+
}
25482560
std::map<std::string, std::vector<int64_t>> input_shapes;
25492561
std::vector<std::string> names = GetInputNames();
25502562
for (std::string const &name : names) {
@@ -2604,6 +2616,9 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() {
26042616

26052617
std::map<std::string, std::vector<int64_t>>
26062618
AnalysisPredictor::GetOutputTensorShape() {
2619+
if (load_pir_model_) {
2620+
return fetch_name2shapes_;
2621+
}
26072622
std::map<std::string, std::vector<int64_t>> output_shapes;
26082623
std::vector<std::string> names = GetOutputNames();
26092624
for (std::string const &name : names) {

paddle/fluid/inference/api/analysis_predictor.h

+2
Original file line numberDiff line numberDiff line change
@@ -578,9 +578,11 @@ class AnalysisPredictor : public PaddlePredictor {
578578
std::map<std::string, size_t> feed_names_;
579579
// Sorted according to the idx.
580580
std::map<size_t, std::string> idx2feeds_;
581+
std::map<std::string, std::vector<int64_t>> feed_name2shapes_;
581582
std::vector<framework::OpDesc *> fetches_;
582583
std::vector<pir::Operation *> pir_fetches_;
583584
std::map<size_t, std::string> idx2fetches_;
585+
std::map<std::string, std::vector<int64_t>> fetch_name2shapes_;
584586

585587
phi::DataType model_precision_{phi::DataType::FLOAT32};
586588

paddle/fluid/pir/transforms/general/inplace_pass.cc

+4
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ bool CanDoInplace(const std::unordered_set<pir::Value>& eager_dels,
8888
return false;
8989
}
9090

91+
// for (auto it = output.use_begin(); it != output.use_end(); ++it) {
92+
// if(it->owner()->isa<pir::ShadowOutputOp>() || )
93+
// }
94+
9195
if (input.type().isa<TensorType>() && output.type().isa<TensorType>()) {
9296
auto input_alloc_tensor_type = input.type().dyn_cast<TensorType>();
9397
auto output_alloc_tensor_type = output.type().dyn_cast<TensorType>();

paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc

+13-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
1818
#include "paddle/fluid/pir/utils/general_functions.h"
1919
#include "paddle/pir/include/core/builtin_op.h"
20+
#include "paddle/pir/include/core/ir_context.h"
2021
#include "paddle/pir/include/pass/pass.h"
2122
#include "paddle/pir/include/pass/pass_registry.h"
2223

@@ -29,13 +30,20 @@ class ReplaceFetchWithShadowOutputPattern
2930
bool MatchAndRewrite(
3031
paddle::dialect::FetchOp op,
3132
pir::PatternRewriter& rewriter) const override { // NOLINT
33+
// for pd_op.data -> value -> pd_op.fetch, we insert pd_op.scale before
34+
// pd_op.fetch to solve error likes [what(): (NotFound) Variable 'xxx' is
35+
// not found in scope.]
3236
if (pir::GetDefiningOpForInput(op, 0)->HasAttribute("name")) {
33-
// DataOp/FeedOp
34-
return false;
37+
auto scale_op = rewriter.Build<paddle::dialect::ScaleOp>(
38+
op->operand_source(0), 1.0, 0.0, true);
39+
rewriter.Build<pir::ShadowOutputOp>(
40+
scale_op->result(0),
41+
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
42+
} else {
43+
rewriter.Build<pir::ShadowOutputOp>(
44+
op->operand_source(0),
45+
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
3546
}
36-
rewriter.Build<pir::ShadowOutputOp>(
37-
op->operand_source(0),
38-
op->attributes().at("name").dyn_cast<pir::StrAttribute>().AsString());
3947
rewriter.EraseOp(op);
4048
return true;
4149
}

0 commit comments

Comments
 (0)