
Commit 4a4f3f8

migrate convs (#47658)
1 parent ca4bed7 commit 4a4f3f8


2 files changed: 57 additions & 254 deletions


paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc

Lines changed: 0 additions & 254 deletions
@@ -776,247 +776,6 @@ class ConvMKLDNNHandlerT

} // anonymous namespace

-#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...) \
-  [&] { \
-    const auto& __dtype__ = TYPE; \
-    switch (__dtype__) { \
-      PD_PRIVATE_CASE_TYPE( \
-          NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
-      PD_PRIVATE_CASE_TYPE( \
-          NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \
-      default: \
-        PD_THROW("function " #NAME " is not implemented for data type `", \
-                 __dtype__, \
-                 "`"); \
-    } \
-  }()
-
-template <typename T>
-class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
-                      true,
-                      platform::errors::PreconditionNotMet(
-                          "Operator DNNL Conv must use CPUPlace"));
-    bool is_INT8 =
-        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-    bool is_BFLOAT16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
-    auto residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-    auto dst_dt = GetDstType(is_INT8,
-                             is_BFLOAT16,
-                             force_fp32_output,
-                             fuse_activation,
-                             fuse_residual_conn,
-                             residual_param);
-    if (!is_INT8) {
-      if (dst_dt == dnnl::memory::data_type::f32) {
-        ComputeFP32<float>(ctx);
-      } else if (dst_dt == dnnl::memory::data_type::bf16) {
-        ComputeFP32<platform::bfloat16>(ctx);
-      }
-    } else {
-      if (dst_dt == dnnl::memory::data_type::f32) {
-        ComputeINT8<float>(ctx);
-      } else if (dst_dt == dnnl::memory::data_type::u8) {
-        ComputeINT8<uint8_t>(ctx);
-      } else if (dst_dt == dnnl::memory::data_type::s8) {
-        ComputeINT8<int8_t>(ctx);
-      }
-    }
-  }
-
-  template <typename T_out>
-  void ComputeFP32(const framework::ExecutionContext& ctx) const {
-    auto& dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    bool is_test = ctx.Attr<bool>("is_test");
-    const auto& strides = ctx.Attr<std::vector<int>>("strides");
-    bool is_conv3d = strides.size() == 3UL;
-    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-    int groups = ctx.Attr<int>("groups");
-
-    const auto* input = ctx.Input<phi::DenseTensor>("Input");
-    const auto* filter = ctx.Input<phi::DenseTensor>("Filter");
-    const auto* bias =
-        ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
-    auto* output = ctx.Output<phi::DenseTensor>("Output");
-
-    PD_VISIT_FLOAT_AND_INT8_TYPES(
-        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
-          ConvMKLDNNHandlerT<T, data_t, T_out> handler(
-              ctx,
-              dev_ctx,
-              mkldnn_engine,
-              ctx.GetPlace(),
-              input,
-              filter,
-              bias,
-              output,
-              ctx.InputName("Input") + ctx.InputName("Filter"));
-
-          auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
-
-          auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-              filter, groups, is_conv3d, is_test);
-
-          std::shared_ptr<dnnl::memory> dst_memory_p;
-          if (fuse_residual_conn) {
-            auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-            dst_memory_p =
-                handler.AcquireDstMemoryWithResidual(output, residual_param);
-          } else {
-            dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
-          }
-
-          auto conv_p = handler.AcquireForwardPrimitive();
-
-          std::unordered_map<int, dnnl::memory> args = {
-              {DNNL_ARG_SRC, *src_memory_p},
-              {DNNL_ARG_WEIGHTS, *weights_memory_p},
-              {DNNL_ARG_DST, *dst_memory_p}};
-
-          if (bias) {
-            auto bias_memory_p =
-                handler.AcquireBiasMemoryWithReorder(bias, is_test);
-            args.insert({DNNL_ARG_BIAS, *bias_memory_p});
-          }
-
-          auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-          conv_p->execute(astream, args);
-          astream.wait();
-
-          output->set_mem_desc(dst_memory_p->get_desc());
-        }));
-  }
-
-  template <typename T_out>
-  void ComputeINT8(const framework::ExecutionContext& ctx) const {
-    auto& dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    const std::string& fuse_activation =
-        ctx.Attr<std::string>("fuse_activation");
-    const bool& fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-    const bool& force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-    const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
-
-    bool unsigned_output =
-        (fuse_activation == "relu" || fuse_activation == "relu6");
-    bool need_s8_to_u8 = false;
-
-    PADDLE_ENFORCE_NE(
-        is_conv3d,
-        true,
-        platform::errors::Unimplemented(
-            "OneDNN int8 convolution does not support 3D inputs currently"));
-    PADDLE_ENFORCE_EQ(
-        fuse_residual_conn && force_fp32_output,
-        false,
-        platform::errors::Unimplemented(
-            "residual fusion does not support force output with fp32"));
-
-    auto* input = ctx.Input<phi::DenseTensor>("Input");
-    auto* filter = ctx.Input<phi::DenseTensor>("Filter");
-    auto* bias =
-        ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
-    auto* output = ctx.Output<phi::DenseTensor>("Output");
-
-    PD_VISIT_FLOAT_AND_INT8_TYPES(
-        filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
-          ConvMKLDNNHandlerT<T, data_t, T_out> handler(
-              ctx,
-              dev_ctx,
-              mkldnn_engine,
-              ctx.GetPlace(),
-              input,
-              filter,
-              bias,
-              output,
-              ctx.InputName("Input") + ctx.InputName("Filter"));
-
-          auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
-
-          const auto& scale_weights_data =
-              ctx.Attr<std::vector<float>>("Scale_weights");
-          const bool is_multi_channel = scale_weights_data.size() > 1;
-          const int& groups = ctx.Attr<int>("groups");
-          int mask_reorder =
-              is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0)
-                               : 0;
-          auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-              filter, groups, false, true, scale_weights_data, mask_reorder);
-
-          std::shared_ptr<dnnl::memory> dst_memory_p;
-          if (fuse_residual_conn) {
-            auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-            PADDLE_ENFORCE_EQ(
-                output->dims(),
-                residual_param->dims(),
-                platform::errors::InvalidArgument(
-                    "Output and elementwise parameter need to have the "
-                    "same dimension sizes, but got output's dimension = %d"
-                    " and residual param's dimension =%d .",
-                    output->dims().size(),
-                    residual_param->dims().size()));
-            dst_memory_p =
-                handler.AcquireDstMemoryWithResidual(output, residual_param);
-            need_s8_to_u8 = (platform::MKLDNNGetDataType<T_out>() ==
-                             dnnl::memory::data_type::s8) &&
-                            unsigned_output;
-          } else {
-            dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
-          }
-
-          auto conv_p = handler.AcquireForwardPrimitive();
-
-          std::unordered_map<int, dnnl::memory> args = {
-              {DNNL_ARG_SRC, *src_memory_p},
-              {DNNL_ARG_WEIGHTS, *weights_memory_p},
-              {DNNL_ARG_DST, *dst_memory_p}};
-
-          if (bias) {
-            std::vector<float> bias_scales;
-            auto p_scales_tuple =
-                std::make_shared<std::tuple<float, std::vector<float>>>(
-                    std::make_tuple(static_cast<float>(mask_reorder),
-                                    bias_scales));
-            if (ctx.HasAttr("Bias_scales")) {
-              bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
-              p_scales_tuple =
-                  std::make_shared<std::tuple<float, std::vector<float>>>(
-                      std::make_tuple(static_cast<float>(mask_reorder),
-                                      bias_scales));
-            } else {
-              p_scales_tuple = handler.get_int8_bias_scales(ctx);
-            }
-            auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
-                bias,
-                true,
-                std::get<1>(*p_scales_tuple),
-                std::get<0>(*p_scales_tuple));
-            args.insert({DNNL_ARG_BIAS, *bias_memory_p});
-          }
-
-          auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-          conv_p->execute(astream, args);
-          astream.wait();
-
-          if (need_s8_to_u8) {
-            output->mutable_data<uint8_t>(ctx.GetPlace());
-          }
-
-          output->set_mem_desc(dst_memory_p->get_desc());
-        }));
-  }
-};
-
#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...) \
  [&] { \
    const auto& __dtype__ = TYPE; \
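Both the removed PD_VISIT_FLOAT_AND_INT8_TYPES macro above and the PD_VISIT_FLOAT_AND_BF16_TYPES macro that stays in the file dispatch on a runtime paddle::DataType and run the caller's lambda with data_t aliased to the matching C++ type; the aliasing happens inside PD_PRIVATE_CASE_TYPE, whose definition is not part of this diff. Below is a minimal standalone analogue of that visitor pattern, using made-up macro and enum names rather than Paddle's real ones:

#include <cstdint>
#include <iostream>

// Illustrative stand-in for the data-type enum; Paddle's real DataType and
// PD_PRIVATE_CASE_TYPE are not shown in this diff.
enum class DataType { FLOAT32, INT8 };

// Expands to one switch case that aliases the matching C++ type as data_t and
// invokes the visitor body, mirroring what each PD_PRIVATE_CASE_TYPE case does.
#define VISIT_CASE(dtype_value, cpp_type, ...) \
  case dtype_value: {                          \
    using data_t = cpp_type;                   \
    __VA_ARGS__();                             \
    break;                                     \
  }

#define VISIT_FLOAT_AND_INT8(dtype, ...)                \
  [&] {                                                 \
    switch (dtype) {                                    \
      VISIT_CASE(DataType::FLOAT32, float, __VA_ARGS__) \
      VISIT_CASE(DataType::INT8, int8_t, __VA_ARGS__)   \
    }                                                   \
  }()

int main() {
  DataType dt = DataType::INT8;
  VISIT_FLOAT_AND_INT8(dt, ([&] {
    // data_t is float in the FLOAT32 case and int8_t in the INT8 case.
    std::cout << "sizeof(data_t) = " << sizeof(data_t) << "\n";
  }));
  return 0;
}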
@@ -1184,25 +943,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {

namespace ops = paddle::operators;

-REGISTER_OP_KERNEL(depthwise_conv2d,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   ops::ConvMKLDNNOpKernel<float>,
-                   ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
-                   ops::ConvMKLDNNOpKernel<uint8_t>,
-                   ops::ConvMKLDNNOpKernel<int8_t>);
-
REGISTER_OP_KERNEL(depthwise_conv2d_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::ConvMKLDNNGradOpKernel<float>,
                   ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);

-REGISTER_OP_KERNEL(conv3d,
-                   MKLDNN,
-                   ::paddle::platform::CPUPlace,
-                   ops::ConvMKLDNNOpKernel<float>);
-
REGISTER_OP_KERNEL(conv3d_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
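The fluid-side forward-kernel registrations removed above are replaced by phi registrations in paddle/phi/kernels/onednn/conv_kernel.cc below. As a condensed side-by-side of the two macro styles, using only identifiers that appear in this commit (single-data-type forms shown for brevity; a sketch, not the full registration lists):

// Fluid style (removed here): an OpKernel class template is instantiated per
// data type and bound to an op name, a library type, and a place.
REGISTER_OP_KERNEL(depthwise_conv2d,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::ConvMKLDNNOpKernel<float>);

// Phi style (added below): a free kernel function template is bound to a
// backend and layout, with the supported data types listed afterwards.
PD_REGISTER_KERNEL(
    depthwise_conv2d, OneDNN, ONEDNN, phi::DepthwiseConvKernel, float) {}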

paddle/phi/kernels/onednn/conv_kernel.cc

Lines changed: 57 additions & 0 deletions
@@ -424,6 +424,52 @@ void ConvKernel(const Context& dev_ctx,
}
}

+template <typename T, typename Context>
+void DepthwiseConvKernel(const Context& dev_ctx,
+                         const DenseTensor& input,
+                         const DenseTensor& filter,
+                         const std::vector<int>& strides,
+                         const std::vector<int>& paddings,
+                         const std::string& padding_algorithm,
+                         int groups,
+                         const std::vector<int>& dilations,
+                         const std::string& data_format,
+                         DenseTensor* out) {
+  ConvKernel<T, Context>(dev_ctx,
+                         input,
+                         filter,
+                         strides,
+                         paddings,
+                         padding_algorithm,
+                         dilations,
+                         groups,
+                         data_format,
+                         out);
+}
+
+template <typename T, typename Context>
+void Conv3DKernel(const Context& dev_ctx,
+                  const DenseTensor& input,
+                  const DenseTensor& filter,
+                  const std::vector<int>& strides,
+                  const std::vector<int>& paddings,
+                  const std::string& padding_algorithm,
+                  int groups,
+                  const std::vector<int>& dilations,
+                  const std::string& data_format,
+                  DenseTensor* out) {
+  ConvKernel<T, Context>(dev_ctx,
+                         input,
+                         filter,
+                         strides,
+                         paddings,
+                         padding_algorithm,
+                         dilations,
+                         groups,
+                         data_format,
+                         out);
+}
+
} // namespace phi

PD_REGISTER_KERNEL(conv2d,
@@ -434,3 +480,14 @@ PD_REGISTER_KERNEL(conv2d,
                   phi::dtype::bfloat16,
                   uint8_t,
                   int8_t) {}
+
+PD_REGISTER_KERNEL(depthwise_conv2d,
+                   OneDNN,
+                   ONEDNN,
+                   phi::DepthwiseConvKernel,
+                   float,
+                   phi::dtype::bfloat16,
+                   uint8_t,
+                   int8_t) {}
+
+PD_REGISTER_KERNEL(conv3d, OneDNN, ONEDNN, phi::Conv3DKernel, float) {}
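Both new kernels are thin wrappers over ConvKernel. One detail worth noticing: the wrapper signatures take groups before dilations, while the forwarded call passes dilations before groups, presumably matching ConvKernel's own parameter order, which is not fully shown in this hunk. A minimal standalone sketch of this delegation-with-reordered-parameters pattern, using made-up names (GeneralConv, DepthwiseConv) rather than Paddle's real APIs:

#include <iostream>
#include <string>
#include <vector>

// Stand-in for the general convolution kernel. Its parameter order here
// (dilations before groups) mirrors the argument order of the forwarding
// call added in this commit.
void GeneralConv(const std::vector<int>& strides,
                 const std::vector<int>& paddings,
                 const std::string& padding_algorithm,
                 const std::vector<int>& dilations,
                 int groups,
                 const std::string& data_format) {
  std::cout << "groups=" << groups << ", dilations[0]=" << dilations[0]
            << ", layout=" << data_format << "\n";
}

// Stand-in for the thin wrapper (cf. DepthwiseConvKernel): it accepts groups
// before dilations, so the forwarding call swaps the two arguments back into
// GeneralConv's order.
void DepthwiseConv(const std::vector<int>& strides,
                   const std::vector<int>& paddings,
                   const std::string& padding_algorithm,
                   int groups,
                   const std::vector<int>& dilations,
                   const std::string& data_format) {
  GeneralConv(strides, paddings, padding_algorithm, dilations, groups,
              data_format);
}

int main() {
  DepthwiseConv({1, 1}, {0, 0}, "EXPLICIT", /*groups=*/8,
                /*dilations=*/{1, 1}, "NCHW");
  return 0;
}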

0 commit comments
