Skip to content

Commit 1237dfa

Browse files
authored
Merge pull request #16885 from NHZlX/fix_anakin_subgraph_shufflenet
Fix anakin subgraph shufflenet
2 parents 5d48e9c + e4726a0 commit 1237dfa

11 files changed

+264
-2
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base)
7070
pass_library(runtime_context_cache_pass base)
7171
pass_library(quant_conv2d_dequant_fuse_pass inference)
7272
pass_library(fillconstant_elementwisemul_fuse inference)
73+
pass_library(shuffle_channel_detect_pass inference)
7374

7475
if(ANAKIN_FOUND)
7576
pass_library(simplify_anakin_priorbox_detection_out_pass inference)

paddle/fluid/framework/ir/graph_pattern_detector.cc

+31
Original file line numberDiff line numberDiff line change
@@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
17061706
}
17071707
}
17081708

1709+
// Registers the node chain
//   reshape1_in -> reshape2 -> (out) -> transpose2 -> (out) -> reshape2 -> (out)
// on the underlying PDPattern. `reshape1_in` is supplied by the caller; the
// two inner outputs are marked intermediate and the final one is the output.
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
  // First reshape2 operator.
  PDNode *first_reshape = pattern->NewNode(reshape1_op_repr());
  first_reshape->assert_is_op("reshape2");

  // Its "Out" variable, which must feed a transpose2.
  PDNode *first_reshape_result = pattern->NewNode(reshape1_out_repr());
  first_reshape_result->assert_is_op_output("reshape2", "Out")
      ->assert_is_op_input("transpose2")
      ->AsIntermediate();

  // The transpose2 operator in the middle of the chain.
  PDNode *transpose = pattern->NewNode(transpose_op_repr());
  transpose->assert_is_op("transpose2");

  // Its "Out" variable, which must feed the second reshape2.
  PDNode *transpose_result = pattern->NewNode(transpose_out_repr());
  transpose_result->assert_is_op_output("transpose2", "Out")
      ->assert_is_op_input("reshape2")
      ->AsIntermediate();

  // Second reshape2 operator and the variable that leaves the pattern.
  PDNode *second_reshape = pattern->NewNode(reshape2_op_repr());
  second_reshape->assert_is_op("reshape2");
  PDNode *second_reshape_result = pattern->NewNode(reshape2_out_repr());
  second_reshape_result->assert_is_op_output("reshape2", "Out")->AsOutput();

  // Wire the chain together, producer to consumer.
  first_reshape->LinksFrom({reshape1_in});
  first_reshape_result->LinksFrom({first_reshape});
  transpose->LinksFrom({first_reshape_result});
  transpose_result->LinksFrom({transpose});
  second_reshape->LinksFrom({transpose_result});
  second_reshape_result->LinksFrom({second_reshape});
}
1739+
17091740
} // namespace ir
17101741
} // namespace framework
17111742
} // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.h

+15
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
892892
}
893893
};
894894

895+
struct ShuffleChannelPattern : public PatternBase {
896+
ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
897+
: PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
898+
899+
void operator()(PDNode* reshape1_in);
900+
901+
PATTERN_DECL_NODE(reshape1_op);
902+
PATTERN_DECL_NODE(reshape1_out);
903+
904+
PATTERN_DECL_NODE(transpose_op);
905+
PATTERN_DECL_NODE(transpose_out);
906+
PATTERN_DECL_NODE(reshape2_op);
907+
PATTERN_DECL_NODE(reshape2_out);
908+
};
909+
895910
} // namespace patterns
896911

897912
// Link two ir::Nodes from each other.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <string>
16+
17+
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
18+
#include "paddle/fluid/framework/ir/shuffle_channel_detect_pass.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
24+
// Helper macros for the subgraph handler in this file.
// GET_IR_NODE(n) expands to GET_IR_NODE_FROM_SUBGRAPH(n, n, pattern), so an
// identifier named `pattern` must be visible at the expansion site.
// GET_NODES applies GET_IR_NODE to the six nodes of the shuffle-channel
// pattern (reshape1_op/out, transpose_op/out, reshape2_op/out).
// Both macros already end with ';', so `GET_NODES;` yields an extra empty
// statement.
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
25+
#define GET_NODES \
26+
GET_IR_NODE(reshape1_op); \
27+
GET_IR_NODE(reshape1_out); \
28+
GET_IR_NODE(transpose_op); \
29+
GET_IR_NODE(transpose_out); \
30+
GET_IR_NODE(reshape2_op); \
31+
GET_IR_NODE(reshape2_out);
32+
33+
// Finds reshape2 -> transpose2 -> reshape2 chains and, when their attributes
// describe a channel shuffle, replaces the three ops with one
// `shuffle_channel` op whose `group` attribute is derived from the reshape
// shapes.
//
// A chain is only rewritten when:
//   * the first reshape produces a 5-D shape and the second a 4-D shape,
//   * the channel entries used to derive `group` are explicit positive
//     values (reshape's 0 / -1 placeholders cannot be resolved here) and
//     divide evenly,
//   * the transpose swaps exactly dims 1 and 2, i.e. axis == {0, 2, 1, 3, 4}.
// Any other chain is left untouched instead of being rewritten into an op
// with different semantics (or crashing on a zero divisor / short shape).
void ShuffleChannelDetectPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "shufflechannel_pattern";
  FusePassBase::Init(pattern_name, graph);

  GraphPatternDetector gpd;
  auto* x = gpd.mutable_pattern()
                ->NewNode("x")
                ->assert_is_op_input("reshape2", "X")
                ->AsInput();

  patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name);
  pattern(x);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;

    PADDLE_ENFORCE(subgraph.count(x));
    auto* input_node = subgraph.at(x);
    auto* reshape1_desc = reshape1_op->Op();
    auto* reshape2_desc = reshape2_op->Op();
    auto* transpose_desc = transpose_op->Op();
    std::string input_name = input_node->Name();
    std::string output_name = reshape2_out->Name();

    auto reshape1_shape =
        boost::get<std::vector<int>>(reshape1_desc->GetAttr("shape"));
    auto reshape2_shape =
        boost::get<std::vector<int>>(reshape2_desc->GetAttr("shape"));
    auto trans_axis =
        boost::get<std::vector<int>>(transpose_desc->GetAttr("axis"));

    // Rank check: guards the indexing below against short or empty
    // attributes.
    if (reshape1_shape.size() != 5 || reshape2_shape.size() != 4 ||
        trans_axis.size() != 5) {
      return;
    }

    // The transpose must swap the group dim and the per-group channel dim
    // and leave every other dim in place; otherwise this is not a shuffle.
    const std::vector<int> shuffle_axis{0, 2, 1, 3, 4};
    if (trans_axis != shuffle_axis) {
      return;
    }

    int i_c = reshape1_shape[2];
    int o_c = reshape2_shape[1];
    // Placeholder values (0 or -1) would give a zero divisor or a
    // meaningless group count, so require explicit, divisible channels.
    if (i_c <= 0 || o_c <= 0 || o_c % i_c != 0) {
      return;
    }
    int group = o_c / i_c;
    // When the group dim is spelled out it has to agree with the derived
    // value.
    if (reshape1_shape[1] > 0 && reshape1_shape[1] != group) {
      return;
    }

    framework::OpDesc new_op_desc;
    new_op_desc.SetType("shuffle_channel");
    new_op_desc.SetInput("X", {input_name});
    new_op_desc.SetOutput("Out", {output_name});

    new_op_desc.SetAttr("group", group);
    new_op_desc.Flush();

    // Create a new node for the fused op.
    auto* new_op = graph->CreateOpNode(&new_op_desc);

    IR_NODE_LINK_TO(input_node, new_op);
    IR_NODE_LINK_TO(new_op, reshape2_out);

    // Delete the unneeded nodes.
    GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
                                 transpose_out, reshape2_op});
  };

  gpd(graph, handler);
}
87+
88+
} // namespace ir
89+
} // namespace framework
90+
} // namespace paddle
91+
92+
REGISTER_PASS(shuffle_channel_detect_pass,
93+
paddle::framework::ir::ShuffleChannelDetectPass);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <vector>
17+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
18+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
24+
class ShuffleChannelDetectPass : public FusePassBase {
25+
public:
26+
virtual ~ShuffleChannelDetectPass() {}
27+
28+
protected:
29+
void ApplyImpl(ir::Graph* graph) const override;
30+
};
31+
32+
} // namespace ir
33+
} // namespace framework
34+
} // namespace paddle

paddle/fluid/inference/anakin/convert/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ cc_library(anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc
22
elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc softmax.cc
33
batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc
44
detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc
5-
roi_align.cc helper.cc DEPS anakin_engine framework_proto scope op_registry
6-
gtest)
5+
roi_align.cc shuffle_channel.cc helper.cc DEPS anakin_engine framework_proto
6+
scope op_registry gtest)
77

88
cc_test(test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL)
99
cc_test(test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/inference/anakin/convert/shuffle_channel.h"
16+
#include <algorithm>
17+
#include <string>
18+
#include <vector>
19+
20+
using anakin::PTuple;
21+
22+
namespace paddle {
23+
namespace inference {
24+
namespace anakin {
25+
26+
template <typename TargetT, ::anakin::Precision PrecisionT>
27+
void ShuffleChannelOpConverter<TargetT, PrecisionT>::operator()(
28+
const framework::proto::OpDesc &op, const framework::BlockDesc &block_desc,
29+
const framework::Scope &scope, bool test_mode) {
30+
framework::OpDesc op_desc(op, nullptr);
31+
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
32+
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
33+
34+
auto input = op_desc.Input("X").front();
35+
auto output = op_desc.Output("Out").front();
36+
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
37+
this->engine_->AddOp(op_name, "ShuffleChannel", {input}, {output});
38+
39+
auto group = boost::get<int>(op_desc.GetAttr("group"));
40+
this->engine_->AddOpAttr(op_name, "group", group);
41+
}
42+
43+
} // namespace anakin
44+
} // namespace inference
45+
} // namespace paddle
46+
47+
REGISTER_ANAKIN_OP_CONVERTER(shuffle_channel, ShuffleChannelOpConverter);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace anakin {
22+
23+
template <typename TargetT, ::anakin::Precision PrecisionT>
24+
class ShuffleChannelOpConverter
25+
: public AnakinOpConverter<TargetT, PrecisionT> {
26+
public:
27+
ShuffleChannelOpConverter() = default;
28+
29+
virtual void operator()(const framework::proto::OpDesc &op,
30+
const framework::BlockDesc &block_desc,
31+
const framework::Scope &scope,
32+
bool test_mode) override;
33+
virtual ~ShuffleChannelOpConverter() {}
34+
};
35+
36+
} // namespace anakin
37+
} // namespace inference
38+
} // namespace paddle

paddle/fluid/inference/anakin/op_teller.cc

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ struct SimpleOpTypeSetTeller : public Teller {
4848
teller_set.insert("affine_channel");
4949
teller_set.insert("relu6");
5050
teller_set.insert("swish");
51+
teller_set.insert("shuffle_channel");
5152
}
5253

5354
bool operator()(const std::string& op_type,

paddle/fluid/inference/api/analysis_predictor.cc

+1
Original file line numberDiff line numberDiff line change
@@ -896,4 +896,5 @@ USE_ANAKIN_CONVERTER(leaky_relu);
896896
USE_ANAKIN_CONVERTER(affine_channel);
897897
USE_ANAKIN_CONVERTER(relu6);
898898
USE_ANAKIN_CONVERTER(swish);
899+
USE_ANAKIN_CONVERTER(shuffle_channel);
899900
#endif

paddle/fluid/inference/api/paddle_pass_builder.cc

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ const std::vector<std::string> kAnakinSubgraphPasses({
7979
"fc_fuse_pass", //
8080
"conv_elementwise_add_fuse_pass", //
8181
"fc_gru_fuse_pass", //
82+
"shuffle_channel_detect_pass", //
8283
"anakin_subgraph_pass", //
8384
"fc_gru_fuse_pass", //
8485
});

0 commit comments

Comments
 (0)