Skip to content

Commit ae60105

Browse files
authored
Handle multiple conv2d_fusion ops sharing the same weight (#51068)
1 parent 4652bee commit ae60105

File tree

2 files changed

+32
-25
lines changed

2 files changed

+32
-25
lines changed

paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc

+29-21
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
142142
auto iter = op_nodes.cbegin();
143143
auto *block_desc = (*iter)->Op()->Block();
144144

145+
// Handle multiple conv2d_fusion ops that share the same weight: track which
// filter variables have already been transposed to NHWC so each is converted
// only once.
146+
std::unordered_set<std::string> weights_shape_nhwc;
147+
148+
// Used to control the insertion of transfer_layout op.
145149
std::unordered_set<ir::Node *> vars_shape_nhwc;
146150

147151
// Only support conv2d_fusion now.
@@ -157,6 +161,9 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
157161
constexpr int NHWC_ALIGNMENT = 8;
158162
// If the filter's channel count is not a multiple of 8, conv2d_fusion cannot run in NHWC layout.
159163
for (const auto &filter_name : filter_names) {
164+
if (weights_shape_nhwc.count(filter_name)) {
165+
continue;
166+
}
160167
auto *filter_var = scope->FindLocalVar(filter_name);
161168
const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
162169
CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
@@ -206,27 +213,28 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
206213
// transfer weights
207214
auto filter_names = op_desc->Input("Filter");
208215
for (const auto &filter_name : filter_names) {
209-
auto *filter_var = scope->FindLocalVar(filter_name);
210-
auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
211-
phi::DenseTensor temp_tensor = *filter_tensor;
212-
filter_tensor->clear();
213-
214-
framework::TransDataLayout(phi::DataLayout::kNCHW,
215-
phi::DataLayout::kNHWC,
216-
phi::CPUPlace{},
217-
temp_tensor,
218-
filter_tensor);
219-
}
220-
auto op_inputs = op_node->inputs;
221-
for (auto *in_var_node : op_inputs) {
222-
CHECK_EQ(in_var_node->IsVar(), true);
223-
if (in_var_node->Var()->Persistable()) {
224-
if (std::find(filter_names.cbegin(),
225-
filter_names.cend(),
226-
in_var_node->Var()->Name()) != filter_names.cend()) {
227-
auto from_shape = in_var_node->Var()->GetShape();
228-
in_var_node->Var()->SetShape(
229-
{from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
216+
if (weights_shape_nhwc.count(filter_name) == 0) {
217+
weights_shape_nhwc.insert(filter_name);
218+
auto *filter_var = scope->FindLocalVar(filter_name);
219+
auto *filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
220+
phi::DenseTensor temp_tensor;
221+
222+
framework::TransDataLayout(phi::DataLayout::kNCHW,
223+
phi::DataLayout::kNHWC,
224+
phi::CPUPlace{},
225+
*filter_tensor,
226+
&temp_tensor);
227+
*filter_tensor = temp_tensor;
228+
229+
auto op_inputs = op_node->inputs;
230+
for (auto *in_var_node : op_inputs) {
231+
CHECK_EQ(in_var_node->IsVar(), true);
232+
if (in_var_node->Var()->Persistable() &&
233+
in_var_node->Var()->Name() == filter_name) {
234+
auto from_shape = in_var_node->Var()->GetShape();
235+
in_var_node->Var()->SetShape(
236+
{from_shape[0], from_shape[2], from_shape[3], from_shape[1]});
237+
}
230238
}
231239
}
232240
}

paddle/fluid/inference/api/paddle_pass_builder.cc

+3-4
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,9 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
270270
"conv_elementwise_add_fuse_pass", //
271271
#endif //
272272
"transpose_flatten_concat_fuse_pass", //
273-
// TODO(liuyuanle): rewrite this pass with new logic
274-
// "conv2d_fusion_layout_transfer_pass", //
275-
"auto_mixed_precision_pass", //
276-
"inplace_op_var_pass", // should be the last pass.
273+
"conv2d_fusion_layout_transfer_pass", //
274+
"auto_mixed_precision_pass", //
275+
"inplace_op_var_pass", // should be the last pass.
277276
});
278277

279278
use_gpu_ = true;

0 commit comments

Comments
 (0)