Skip to content

Commit 9f80c7f

Browse files
authored
[Inference] add add_shadow_output_after_dead_parameter_pass (#68476)
* add add_shadow_output_after_dead_parameter_pass * fix * fix * fix ci
1 parent d5fd91c commit 9f80c7f

File tree

5 files changed

+107
-1
lines changed

5 files changed

+107
-1
lines changed

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ IpuPassStrategy::IpuPassStrategy() : PassStrategy({}) {
581581

582582
const std::vector<std::string> kPirGpuPasses{
583583
// Functional pass
584+
"add_shadow_output_after_dead_parameter_pass",
584585
"delete_quant_dequant_linear_op_pass",
585586
"delete_weight_dequant_linear_op_pass",
586587
"map_op_to_another_pass",
@@ -609,6 +610,7 @@ const std::vector<std::string> kPirGpuPasses{
609610

610611
const std::vector<std::string> kPirXpuPasses{
611612
// Functional pass
613+
"add_shadow_output_after_dead_parameter_pass",
612614
"delete_quant_dequant_linear_op_pass",
613615
"delete_weight_dequant_linear_op_pass",
614616
"map_op_to_another_pass",
@@ -622,7 +624,8 @@ const std::vector<std::string> kPirXpuPasses{
622624
"fc_xpu_fuse_pass"};
623625

624626
const std::vector<std::string> kPirMkldnnPasses {
625-
"delete_quant_dequant_linear_op_pass", //
627+
"add_shadow_output_after_dead_parameter_pass",
628+
"delete_quant_dequant_linear_op_pass", //
626629
"delete_weight_dequant_linear_op_pass", //
627630
"depthwise_conv_onednn_pass", //
628631
"squeeze_transpose_onednn_fuse_pass", //
@@ -663,13 +666,15 @@ const std::vector<std::string> kPirMkldnnPasses {
663666
};
664667

665668
const std::vector<std::string> kPirMkldnnBf16Passes{
669+
"add_shadow_output_after_dead_parameter_pass",
666670
"cpu_bfloat16_placement_pass",
667671
"cpu_bfloat16_pass",
668672
"cpu_bfloat16_type_placement_pass",
669673
"cpu_bf16_quantize_squash_pass",
670674
};
671675

672676
const std::vector<std::string> kPirCpuPasses{
677+
"add_shadow_output_after_dead_parameter_pass",
673678
"delete_quant_dequant_linear_op_pass",
674679
"delete_weight_dequant_linear_op_pass"};
675680

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/pir/transforms/general/add_shadow_output_after_dead_parameter_pass.h"
16+
17+
#include "paddle/pir/include/core/builtin_op.h"
18+
#include "paddle/pir/include/pass/pass.h"
19+
#include "paddle/pir/include/pass/pass_registry.h"
20+
21+
namespace {
22+
23+
class AddShadowOutputAfterDeadParameterPattern
24+
: public pir::OpRewritePattern<pir::ParameterOp> {
25+
public:
26+
using pir::OpRewritePattern<pir::ParameterOp>::OpRewritePattern;
27+
bool MatchAndRewrite(
28+
pir::ParameterOp op,
29+
pir::PatternRewriter& rewriter) const override { // NOLINT
30+
if (!op->use_empty()) {
31+
return false;
32+
}
33+
rewriter.SetInsertionPointToBlockEnd(op->GetParent());
34+
rewriter.Build<pir::ShadowOutputOp>(op->result(0), op.param_name());
35+
return true;
36+
}
37+
};
38+
39+
class AddShadowOutputAfterDeadParameterPass : public pir::PatternRewritePass {
40+
public:
41+
AddShadowOutputAfterDeadParameterPass()
42+
: pir::PatternRewritePass("add_shadow_output_after_dead_parameter_pass",
43+
0) {}
44+
45+
pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
46+
pir::RewritePatternSet ps(context);
47+
ps.Add<AddShadowOutputAfterDeadParameterPattern>(context);
48+
return ps;
49+
}
50+
51+
bool CanApplyOn(pir::Operation* op) const override {
52+
return op->isa<::pir::ModuleOp>() && op->num_regions() > 0;
53+
}
54+
55+
private:
56+
pir::FrozenRewritePatternSet patterns_;
57+
};
58+
59+
} // namespace
60+
61+
namespace pir {
62+
63+
std::unique_ptr<pir::Pass> CreateAddShadowOutputAfterDeadParameterPass() {
64+
return std::make_unique<AddShadowOutputAfterDeadParameterPass>();
65+
}
66+
67+
} // namespace pir
68+
69+
REGISTER_IR_PASS(add_shadow_output_after_dead_parameter_pass,
70+
AddShadowOutputAfterDeadParameterPass);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include "paddle/pir/include/core/dll_decl.h"
19+
20+
namespace pir {
21+
22+
class Pass;
23+
24+
IR_API std::unique_ptr<Pass> CreateAddShadowOutputAfterDeadParameterPass();
25+
26+
} // namespace pir

paddle/fluid/pir/transforms/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ USE_PIR_PASS(transfer_layout_pass);
4848
USE_PIR_PASS(fused_rotary_position_embedding_pass);
4949
USE_PIR_PASS(horizontal_fuse_pass);
5050
USE_PIR_PASS(common_subexpression_elimination_pass);
51+
USE_PIR_PASS(add_shadow_output_after_dead_parameter_pass);
5152

5253
#ifdef PADDLE_WITH_DNNL
5354
USE_PIR_PASS(depthwise_conv_onednn_pass);

test/ir/inference/test_inference_predictor_run_pir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(self):
2727
super().__init__()
2828
self.fc1 = paddle.nn.Linear(4, 4)
2929
self.fc2 = paddle.nn.Linear(4, 4)
30+
self.register_buffer("buffer", paddle.randn([5, 1]))
3031

3132
def forward(self, x1, x2):
3233
y1 = self.fc1(x1)
@@ -122,6 +123,9 @@ def init_pir_predictor(self):
122123
# config.enable_memory_optim()
123124
config.enable_new_executor()
124125
config.enable_new_ir()
126+
config.switch_ir_debug(
127+
True, ['add_shadow_output_after_dead_parameter_pass']
128+
)
125129
predictor = create_predictor(config)
126130
return predictor
127131

0 commit comments

Comments
 (0)