Skip to content

Commit af9b069

Browse files
[OneDNN][PIR] Add depthwise_conv_onednn_pass (#63051)
* first commit of depthwise conv pass * style fix * add copy_onnx * add other create to test * check if onednn pass not register * add ifdef PADDLE_WITH_DNNL * add WITH_MKLDNN * fix style bug * add PADDLE_WITH_DNNL * add condition in onnx * SKIP WIN32 CI * name change mkl to onednn * change name * use python ut for depthwise conv * delete skipif * Rename test_depthwise_conv_onednn_pass.py to test_pir_depthwise_conv_onednn_pass.py
1 parent e7a515b commit af9b069

File tree

5 files changed

+208
-0
lines changed

5 files changed

+208
-0
lines changed

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ const std::vector<std::string> kPirXpuPasses{// Functional pass
621621
"add_layernorm_xpu_fuse_pass"};
622622

623623
const std::vector<std::string> kPirMkldnnPasses{
624+
"depthwise_conv_onednn_pass",
624625
"squeeze_transpose_onednn_fuse_pass",
625626
"conv2d_bias_fuse_pass",
626627
"conv2d_transpose_bias_fuse_pass",
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/pir/transforms/onednn/depthwise_conv_onednn_pass.h"
16+
17+
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
18+
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
19+
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
20+
21+
#include "paddle/pir/include/pass/pass.h"
22+
#include "paddle/pir/include/pass/pass_registry.h"
23+
24+
namespace {
25+
26+
class DepthwiseConvPattern : public paddle::drr::DrrPatternBase {
27+
private:
28+
std::string depthwise_conv_name_;
29+
30+
public:
31+
explicit DepthwiseConvPattern(const std::string &conv_name)
32+
: depthwise_conv_name_(conv_name) {}
33+
34+
std::string name() const override { return "DepthwiseConvPattern"; }
35+
36+
uint32_t benefit() const override { return 2; }
37+
38+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
39+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
40+
41+
const auto &depthwise_conv =
42+
pat.Op(depthwise_conv_name_,
43+
{{"strides", pat.Attr("strides")},
44+
{"paddings", pat.Attr("paddings")},
45+
{"padding_algorithm", pat.Attr("padding_algorithm")},
46+
{"dilations", pat.Attr("dilations")},
47+
{"groups", pat.Attr("groups")},
48+
{"data_format", pat.Attr("data_format")}});
49+
50+
depthwise_conv({&pat.Tensor("input"), &pat.Tensor("filter")},
51+
{&pat.Tensor("conv_out")});
52+
53+
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
54+
std::set<std::string> padding_algorithm = {"EXPLICIT", "SAME", "VALID"};
55+
std::set<std::string> data_format = {"NCHW", "NHWC", "AnyLayout"};
56+
if (padding_algorithm.count(
57+
match_ctx.Attr<std::string>("padding_algorithm")) == 0 ||
58+
data_format.count(match_ctx.Attr<std::string>("data_format")) == 0 ||
59+
match_ctx.Attr<int>("groups") < 1) {
60+
return false;
61+
}
62+
return true;
63+
});
64+
65+
paddle::drr::ResultPattern res = pat.ResultPattern();
66+
67+
const auto &conv2d =
68+
res.Op(paddle::dialect::Conv2dOp::name(),
69+
{{
70+
{"strides", pat.Attr("strides")},
71+
{"paddings", pat.Attr("paddings")},
72+
{"padding_algorithm", pat.Attr("padding_algorithm")},
73+
{"dilations", pat.Attr("dilations")},
74+
{"groups", pat.Attr("groups")},
75+
{"data_format", pat.Attr("data_format")},
76+
}});
77+
78+
conv2d({&res.Tensor("input"), &res.Tensor("filter")},
79+
{&res.Tensor("conv_out")});
80+
}
81+
};
82+
83+
class DepthwiseConvMKLDNNPass : public pir::PatternRewritePass {
84+
public:
85+
DepthwiseConvMKLDNNPass()
86+
: pir::PatternRewritePass("depthwise_conv_mkldnn_pass", 2) {}
87+
88+
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
89+
pir::RewritePatternSet ps(context);
90+
ps.Add(paddle::drr::Create<DepthwiseConvPattern>(
91+
context, paddle::dialect::DepthwiseConv2dOp::name()));
92+
return ps;
93+
}
94+
};
95+
96+
} // namespace
97+
98+
namespace pir {
99+
100+
std::unique_ptr<Pass> CreateDepthwiseConvMKLDNNPass() {
101+
// pd_op.depthwise_conv -> pd_op.conv2d
102+
return std::make_unique<DepthwiseConvMKLDNNPass>();
103+
}
104+
105+
} // namespace pir
106+
107+
REGISTER_IR_PASS(depthwise_conv_onednn_pass, DepthwiseConvMKLDNNPass);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include "paddle/pir/include/core/dll_decl.h"
19+
20+
namespace pir {
21+
22+
class Pass;
23+
24+
IR_API std::unique_ptr<Pass> CreateDepthwiseConvMKLDNNPass();
25+
26+
} // namespace pir

paddle/fluid/pir/transforms/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ USE_PIR_PASS(fused_dot_product_attention_pass);
4141
USE_PIR_PASS(fused_flash_attn_pass);
4242

4343
#ifdef PADDLE_WITH_DNNL
44+
USE_PIR_PASS(depthwise_conv_onednn_pass);
4445
USE_PIR_PASS(squeeze_transpose_onednn_fuse_pass);
4546
USE_PIR_PASS(batch_norm_act_fuse_pass);
4647
USE_PIR_PASS(conv2d_bias_fuse_pass);
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import unittest
15+
16+
import numpy as np
17+
from pass_test import PassTest
18+
19+
import paddle
20+
21+
paddle.enable_static()
22+
23+
24+
class TestConv2dAddFusePass(PassTest):
25+
def is_program_valid(self, program=None):
26+
return True
27+
28+
def build_ir_program(self):
29+
with paddle.pir_utils.IrGuard():
30+
main_prog = paddle.static.Program()
31+
start_prog = paddle.static.Program()
32+
with paddle.pir.core.program_guard(main_prog, start_prog):
33+
x = paddle.static.data(
34+
name='x', shape=[5, 2, 5, 5], dtype='float32'
35+
)
36+
37+
conv2d = paddle.nn.Conv2D(
38+
in_channels=2,
39+
out_channels=2,
40+
kernel_size=[2, 2],
41+
groups=2,
42+
stride=[1, 1],
43+
padding=[1, 1, 1, 1],
44+
dilation=[1, 1],
45+
data_format='NCHW',
46+
bias_attr=False,
47+
)
48+
49+
conv2d_out = conv2d(x)
50+
out = paddle.assign(conv2d_out)
51+
self.pass_list = ['depthwise_conv_onednn_pass']
52+
53+
self.feeds = {
54+
"x": np.random.random((5, 2, 5, 5)).astype("float32"),
55+
}
56+
self.fetch_list = [out]
57+
self.valid_op_map = {
58+
"pd_op.conv2d": 1,
59+
}
60+
return [main_prog, start_prog]
61+
62+
def sample_program(self):
63+
yield self.build_ir_program(), False
64+
65+
def setUp(self):
66+
self.places.append(paddle.CPUPlace())
67+
68+
def test_check_output(self):
69+
self.check_pass_correct()
70+
71+
72+
if __name__ == "__main__":
73+
unittest.main()

0 commit comments

Comments
 (0)