Skip to content

Commit ec04916

Browse files
jakpiaseAnnaTrainingG
authored andcommitted
Disabled oneDNN reshape1/2 and squeeze1/2 kernels (PaddlePaddle#35781)
* disabled matmul_v2 grad * Revert "disabled matmul_v2 grad" This reverts commit b569bce. * reverted disabling matmul_v2, disabled reshape and squeeze
1 parent 00e5c4b commit ec04916

File tree

2 files changed

+35
-35
lines changed

2 files changed

+35
-35
lines changed

paddle/fluid/operators/reshape_op.cc

+15-15
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
249249
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
250250

251251
#ifdef PADDLE_WITH_MKLDNN
252-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
253-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
254-
framework::DataLayout::kMKLDNN,
255-
framework::LibraryType::kMKLDNN);
256-
}
252+
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
253+
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
254+
// framework::DataLayout::kMKLDNN,
255+
// framework::LibraryType::kMKLDNN);
256+
// }
257257
#endif
258258
return framework::OpKernelType(input_data_type, ctx.GetPlace());
259259
}
@@ -367,11 +367,11 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
367367
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
368368

369369
#ifdef PADDLE_WITH_MKLDNN
370-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
371-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
372-
framework::DataLayout::kMKLDNN,
373-
framework::LibraryType::kMKLDNN);
374-
}
370+
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
371+
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
372+
// framework::DataLayout::kMKLDNN,
373+
// framework::LibraryType::kMKLDNN);
374+
// }
375375
#endif
376376
return framework::OpKernelType(input_data_type, ctx.GetPlace());
377377
}
@@ -558,11 +558,11 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
558558
ctx, framework::GradVarName("Out"));
559559

560560
#ifdef PADDLE_WITH_MKLDNN
561-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
562-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
563-
framework::DataLayout::kMKLDNN,
564-
framework::LibraryType::kMKLDNN);
565-
}
561+
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
562+
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
563+
// framework::DataLayout::kMKLDNN,
564+
// framework::LibraryType::kMKLDNN);
565+
// }
566566
#endif
567567
return framework::OpKernelType(input_data_type, ctx.GetPlace());
568568
}

paddle/fluid/operators/squeeze_op.cc

+20-20
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ class SqueezeOp : public framework::OperatorWithKernel {
114114
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
115115

116116
#ifdef PADDLE_WITH_MKLDNN
117-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
118-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
119-
framework::DataLayout::kMKLDNN,
120-
framework::LibraryType::kMKLDNN);
121-
}
117+
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
118+
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
119+
// framework::DataLayout::kMKLDNN,
120+
// framework::LibraryType::kMKLDNN);
121+
// }
122122
#endif
123123
return framework::OpKernelType(input_data_type, ctx.GetPlace());
124124
}
@@ -141,11 +141,11 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
141141
ctx, framework::GradVarName("Out"));
142142

143143
#ifdef PADDLE_WITH_MKLDNN
144-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
145-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
146-
framework::DataLayout::kMKLDNN,
147-
framework::LibraryType::kMKLDNN);
148-
}
144+
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
145+
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
146+
// framework::DataLayout::kMKLDNN,
147+
// framework::LibraryType::kMKLDNN);
148+
// }
149149
#endif
150150
return framework::OpKernelType(input_data_type, ctx.GetPlace());
151151
}
@@ -242,11 +242,11 @@ class Squeeze2Op : public framework::OperatorWithKernel {
242242
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
243243

244244
#ifdef PADDLE_WITH_MKLDNN
245-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
246-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
247-
framework::DataLayout::kMKLDNN,
248-
framework::LibraryType::kMKLDNN);
249-
}
245+
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
246+
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
247+
// framework::DataLayout::kMKLDNN,
248+
// framework::LibraryType::kMKLDNN);
249+
// }
250250
#endif
251251
return framework::OpKernelType(input_data_type, ctx.GetPlace());
252252
}
@@ -288,11 +288,11 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
288288
ctx, framework::GradVarName("Out"));
289289

290290
#ifdef PADDLE_WITH_MKLDNN
291-
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
292-
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
293-
framework::DataLayout::kMKLDNN,
294-
framework::LibraryType::kMKLDNN);
295-
}
291+
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
292+
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
293+
// framework::DataLayout::kMKLDNN,
294+
// framework::LibraryType::kMKLDNN);
295+
// }
296296
#endif
297297
return framework::OpKernelType(input_data_type, ctx.GetPlace());
298298
}

0 commit comments

Comments
 (0)