@@ -142,6 +142,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
142
142
auto iter = op_nodes.cbegin ();
143
143
auto *block_desc = (*iter)->Op ()->Block ();
144
144
145
+ // Handle the case where multiple conv2d_fusion ops share the same weight.
146
+ std::unordered_set<std::string> weights_shape_nhwc;
147
+
148
+ // Used to control the insertion of transfer_layout op.
145
149
std::unordered_set<ir::Node *> vars_shape_nhwc;
146
150
147
151
// Only support conv2d_fusion now.
@@ -157,6 +161,9 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
157
161
constexpr int NHWC_ALIGNMENT = 8 ;
158
162
// If the filter's channel is not a multiple of 8, conv2d_fusion does not run in NHWC.
159
163
for (const auto &filter_name : filter_names) {
164
+ if (weights_shape_nhwc.count (filter_name)) {
165
+ continue ;
166
+ }
160
167
auto *filter_var = scope->FindLocalVar (filter_name);
161
168
const auto &filter_tensor = filter_var->Get <phi::DenseTensor>();
162
169
CHECK_EQ (filter_tensor.dims ().size () == 4UL , true );
@@ -206,27 +213,28 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
206
213
// transfer weights
207
214
auto filter_names = op_desc->Input (" Filter" );
208
215
for (const auto &filter_name : filter_names) {
209
- auto *filter_var = scope->FindLocalVar (filter_name);
210
- auto *filter_tensor = filter_var->GetMutable <phi::DenseTensor>();
211
- phi::DenseTensor temp_tensor = *filter_tensor;
212
- filter_tensor->clear ();
213
-
214
- framework::TransDataLayout (phi::DataLayout::kNCHW ,
215
- phi::DataLayout::kNHWC ,
216
- phi::CPUPlace{},
217
- temp_tensor,
218
- filter_tensor);
219
- }
220
- auto op_inputs = op_node->inputs ;
221
- for (auto *in_var_node : op_inputs) {
222
- CHECK_EQ (in_var_node->IsVar (), true );
223
- if (in_var_node->Var ()->Persistable ()) {
224
- if (std::find (filter_names.cbegin (),
225
- filter_names.cend (),
226
- in_var_node->Var ()->Name ()) != filter_names.cend ()) {
227
- auto from_shape = in_var_node->Var ()->GetShape ();
228
- in_var_node->Var ()->SetShape (
229
- {from_shape[0 ], from_shape[2 ], from_shape[3 ], from_shape[1 ]});
216
+ if (weights_shape_nhwc.count (filter_name) == 0 ) {
217
+ weights_shape_nhwc.insert (filter_name);
218
+ auto *filter_var = scope->FindLocalVar (filter_name);
219
+ auto *filter_tensor = filter_var->GetMutable <phi::DenseTensor>();
220
+ phi::DenseTensor temp_tensor;
221
+
222
+ framework::TransDataLayout (phi::DataLayout::kNCHW ,
223
+ phi::DataLayout::kNHWC ,
224
+ phi::CPUPlace{},
225
+ *filter_tensor,
226
+ &temp_tensor);
227
+ *filter_tensor = temp_tensor;
228
+
229
+ auto op_inputs = op_node->inputs ;
230
+ for (auto *in_var_node : op_inputs) {
231
+ CHECK_EQ (in_var_node->IsVar (), true );
232
+ if (in_var_node->Var ()->Persistable () &&
233
+ in_var_node->Var ()->Name () == filter_name) {
234
+ auto from_shape = in_var_node->Var ()->GetShape ();
235
+ in_var_node->Var ()->SetShape (
236
+ {from_shape[0 ], from_shape[2 ], from_shape[3 ], from_shape[1 ]});
237
+ }
230
238
}
231
239
}
232
240
}
0 commit comments