@@ -1031,7 +1031,14 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
1031
1031
sub_scope_);
1032
1032
basic_pass_pm.AddPass (std::move (dead_code_elimination_pass));
1033
1033
}
1034
- basic_pass_pm.AddPass (::pir::CreateReplaceFetchWithShadowOutputPass ());
1034
+ auto replace_fetch_with_shadow_output_pass =
1035
+ ::pir::CreateReplaceFetchWithShadowOutputPass ();
1036
+ if (std::find (config_.deleted_passes_ .begin (),
1037
+ config_.deleted_passes_ .end (),
1038
+ replace_fetch_with_shadow_output_pass->name ()) ==
1039
+ config_.deleted_passes_ .end ()) {
1040
+ basic_pass_pm.AddPass (std::move (replace_fetch_with_shadow_output_pass));
1041
+ }
1035
1042
if (!config_.glog_info_disabled ()) {
1036
1043
basic_pass_pm.EnablePrintStatistics ();
1037
1044
}
@@ -1090,6 +1097,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) {
1090
1097
std::string fetch_name =
1091
1098
op->attribute (" name" ).dyn_cast <pir::StrAttribute>().AsString ();
1092
1099
idx2fetches_[idx] = fetch_name;
1100
+ fetch_name2shapes_[fetch_name] = pir::GetShapeFromValue (op->operand_source (0 ));
1093
1101
}
1094
1102
} else if (op->isa <paddle::dialect::DataOp>() ||
1095
1103
op->isa <paddle::dialect::FeedOp>()) {
@@ -1102,6 +1110,7 @@ bool AnalysisPredictor::SaveOrLoadPirParameters(bool for_save) {
1102
1110
feed_names_[data_name] = feed_idx;
1103
1111
feed_idx++;
1104
1112
pir_feeds_.emplace_back (op);
1113
+ feed_name2shapes_[data_name] = pir::GetShapeFromValue (op->result (0 ));
1105
1114
}
1106
1115
1107
1116
if (op->isa <::pir::ParameterOp>()) {
@@ -2545,6 +2554,9 @@ std::vector<std::string> AnalysisPredictor::GetInputNames() {
2545
2554
2546
2555
std::map<std::string, std::vector<int64_t >>
2547
2556
AnalysisPredictor::GetInputTensorShape () {
2557
+ if (load_pir_model_) {
2558
+ return feed_name2shapes_;
2559
+ }
2548
2560
std::map<std::string, std::vector<int64_t >> input_shapes;
2549
2561
std::vector<std::string> names = GetInputNames ();
2550
2562
for (std::string const &name : names) {
@@ -2604,6 +2616,9 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() {
2604
2616
2605
2617
std::map<std::string, std::vector<int64_t >>
2606
2618
AnalysisPredictor::GetOutputTensorShape () {
2619
+ if (load_pir_model_) {
2620
+ return fetch_name2shapes_;
2621
+ }
2607
2622
std::map<std::string, std::vector<int64_t >> output_shapes;
2608
2623
std::vector<std::string> names = GetOutputNames ();
2609
2624
for (std::string const &name : names) {
0 commit comments