@@ -776,247 +776,6 @@ class ConvMKLDNNHandlerT
}  // anonymous namespace

- #define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...)                      \
-   [&] {                                                                      \
-     const auto& __dtype__ = TYPE;                                            \
-     switch (__dtype__) {                                                     \
-       PD_PRIVATE_CASE_TYPE(                                                  \
-           NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)             \
-       PD_PRIVATE_CASE_TYPE(                                                  \
-           NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__)               \
-       default:                                                               \
-         PD_THROW("function " #NAME " is not implemented for data type `",    \
-                  __dtype__,                                                  \
-                  "`");                                                       \
-     }                                                                        \
-   }()
-
- template <typename T>
- class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
-  public:
-   void Compute(const framework::ExecutionContext& ctx) const override {
-     PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
-                       true,
-                       platform::errors::PreconditionNotMet(
-                           "Operator DNNL Conv must use CPUPlace"));
-     bool is_INT8 =
-         std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-     bool is_BFLOAT16 = ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
-     auto residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-     std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-     bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-     auto dst_dt = GetDstType(is_INT8,
-                              is_BFLOAT16,
-                              force_fp32_output,
-                              fuse_activation,
-                              fuse_residual_conn,
-                              residual_param);
-     if (!is_INT8) {
-       if (dst_dt == dnnl::memory::data_type::f32) {
-         ComputeFP32<float>(ctx);
-       } else if (dst_dt == dnnl::memory::data_type::bf16) {
-         ComputeFP32<platform::bfloat16>(ctx);
-       }
-     } else {
-       if (dst_dt == dnnl::memory::data_type::f32) {
-         ComputeINT8<float>(ctx);
-       } else if (dst_dt == dnnl::memory::data_type::u8) {
-         ComputeINT8<uint8_t>(ctx);
-       } else if (dst_dt == dnnl::memory::data_type::s8) {
-         ComputeINT8<int8_t>(ctx);
-       }
-     }
-   }
-
-   template <typename T_out>
-   void ComputeFP32(const framework::ExecutionContext& ctx) const {
-     auto& dev_ctx =
-         ctx.template device_context<platform::MKLDNNDeviceContext>();
-     const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-     bool is_test = ctx.Attr<bool>("is_test");
-     const auto& strides = ctx.Attr<std::vector<int>>("strides");
-     bool is_conv3d = strides.size() == 3UL;
-     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-     int groups = ctx.Attr<int>("groups");
-
-     const auto* input = ctx.Input<phi::DenseTensor>("Input");
-     const auto* filter = ctx.Input<phi::DenseTensor>("Filter");
-     const auto* bias =
-         ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
-     auto* output = ctx.Output<phi::DenseTensor>("Output");
-
-     PD_VISIT_FLOAT_AND_INT8_TYPES(
-         filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
-           ConvMKLDNNHandlerT<T, data_t, T_out> handler(
-               ctx,
-               dev_ctx,
-               mkldnn_engine,
-               ctx.GetPlace(),
-               input,
-               filter,
-               bias,
-               output,
-               ctx.InputName("Input") + ctx.InputName("Filter"));
-
-           auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
-
-           auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-               filter, groups, is_conv3d, is_test);
-
-           std::shared_ptr<dnnl::memory> dst_memory_p;
-           if (fuse_residual_conn) {
-             auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-             dst_memory_p =
-                 handler.AcquireDstMemoryWithResidual(output, residual_param);
-           } else {
-             dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
-           }
-
-           auto conv_p = handler.AcquireForwardPrimitive();
-
-           std::unordered_map<int, dnnl::memory> args = {
-               {DNNL_ARG_SRC, *src_memory_p},
-               {DNNL_ARG_WEIGHTS, *weights_memory_p},
-               {DNNL_ARG_DST, *dst_memory_p}};
-
-           if (bias) {
-             auto bias_memory_p =
-                 handler.AcquireBiasMemoryWithReorder(bias, is_test);
-             args.insert({DNNL_ARG_BIAS, *bias_memory_p});
-           }
-
-           auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-           conv_p->execute(astream, args);
-           astream.wait();
-
-           output->set_mem_desc(dst_memory_p->get_desc());
-         }));
-   }
-
-   template <typename T_out>
-   void ComputeINT8(const framework::ExecutionContext& ctx) const {
-     auto& dev_ctx =
-         ctx.template device_context<platform::MKLDNNDeviceContext>();
-     const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-     const std::string& fuse_activation =
-         ctx.Attr<std::string>("fuse_activation");
-     const bool& fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-     const bool& force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-     const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
-
-     bool unsigned_output =
-         (fuse_activation == "relu" || fuse_activation == "relu6");
-     bool need_s8_to_u8 = false;
-
-     PADDLE_ENFORCE_NE(
-         is_conv3d,
-         true,
-         platform::errors::Unimplemented(
-             "OneDNN int8 convolution does not support 3D inputs currently"));
-     PADDLE_ENFORCE_EQ(
-         fuse_residual_conn && force_fp32_output,
-         false,
-         platform::errors::Unimplemented(
-             "residual fusion does not support force output with fp32"));
-
-     auto* input = ctx.Input<phi::DenseTensor>("Input");
-     auto* filter = ctx.Input<phi::DenseTensor>("Filter");
-     auto* bias =
-         ctx.HasInput("Bias") ? ctx.Input<phi::DenseTensor>("Bias") : nullptr;
-     auto* output = ctx.Output<phi::DenseTensor>("Output");
-
-     PD_VISIT_FLOAT_AND_INT8_TYPES(
-         filter->dtype(), "ConvMKLDNNHandlerT", ([&] {
-           ConvMKLDNNHandlerT<T, data_t, T_out> handler(
-               ctx,
-               dev_ctx,
-               mkldnn_engine,
-               ctx.GetPlace(),
-               input,
-               filter,
-               bias,
-               output,
-               ctx.InputName("Input") + ctx.InputName("Filter"));
-
-           auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
-
-           const auto& scale_weights_data =
-               ctx.Attr<std::vector<float>>("Scale_weights");
-           const bool is_multi_channel = scale_weights_data.size() > 1;
-           const int& groups = ctx.Attr<int>("groups");
-           int mask_reorder =
-               is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0)
-                                : 0;
-           auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
-               filter, groups, false, true, scale_weights_data, mask_reorder);
-
-           std::shared_ptr<dnnl::memory> dst_memory_p;
-           if (fuse_residual_conn) {
-             auto* residual_param = ctx.Input<phi::DenseTensor>("ResidualData");
-             PADDLE_ENFORCE_EQ(
-                 output->dims(),
-                 residual_param->dims(),
-                 platform::errors::InvalidArgument(
-                     "Output and elementwise parameter need to have the "
-                     "same dimension sizes, but got output's dimension = %d"
-                     " and residual param's dimension = %d.",
-                     output->dims().size(),
-                     residual_param->dims().size()));
-             dst_memory_p =
-                 handler.AcquireDstMemoryWithResidual(output, residual_param);
-             need_s8_to_u8 = (platform::MKLDNNGetDataType<T_out>() ==
-                              dnnl::memory::data_type::s8) &&
-                             unsigned_output;
-           } else {
-             dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
-           }
-
-           auto conv_p = handler.AcquireForwardPrimitive();
-
-           std::unordered_map<int, dnnl::memory> args = {
-               {DNNL_ARG_SRC, *src_memory_p},
-               {DNNL_ARG_WEIGHTS, *weights_memory_p},
-               {DNNL_ARG_DST, *dst_memory_p}};
-
-           if (bias) {
-             std::vector<float> bias_scales;
-             auto p_scales_tuple =
-                 std::make_shared<std::tuple<float, std::vector<float>>>(
-                     std::make_tuple(static_cast<float>(mask_reorder),
-                                     bias_scales));
-             if (ctx.HasAttr("Bias_scales")) {
-               bias_scales = ctx.Attr<std::vector<float>>("Bias_scales");
-               p_scales_tuple =
-                   std::make_shared<std::tuple<float, std::vector<float>>>(
-                       std::make_tuple(static_cast<float>(mask_reorder),
-                                       bias_scales));
-             } else {
-               p_scales_tuple = handler.get_int8_bias_scales(ctx);
-             }
-             auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
-                 bias,
-                 true,
-                 std::get<1>(*p_scales_tuple),
-                 std::get<0>(*p_scales_tuple));
-             args.insert({DNNL_ARG_BIAS, *bias_memory_p});
-           }
-
-           auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-           conv_p->execute(astream, args);
-           astream.wait();
-
-           if (need_s8_to_u8) {
-             output->mutable_data<uint8_t>(ctx.GetPlace());
-           }
-
-           output->set_mem_desc(dst_memory_p->get_desc());
-         }));
-   }
- };
-
#define PD_VISIT_FLOAT_AND_BF16_TYPES(TYPE, NAME, ...)                        \
  [&] {                                                                       \
    const auto& __dtype__ = TYPE;                                             \
@@ -1184,25 +943,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;

- REGISTER_OP_KERNEL(depthwise_conv2d,
-                    MKLDNN,
-                    ::paddle::platform::CPUPlace,
-                    ops::ConvMKLDNNOpKernel<float>,
-                    ops::ConvMKLDNNOpKernel<paddle::platform::bfloat16>,
-                    ops::ConvMKLDNNOpKernel<uint8_t>,
-                    ops::ConvMKLDNNOpKernel<int8_t>);
-
REGISTER_OP_KERNEL(depthwise_conv2d_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,
                   ops::ConvMKLDNNGradOpKernel<float>,
                   ops::ConvMKLDNNGradOpKernel<paddle::platform::bfloat16>);

- REGISTER_OP_KERNEL(conv3d,
-                    MKLDNN,
-                    ::paddle::platform::CPUPlace,
-                    ops::ConvMKLDNNOpKernel<float>);
-
REGISTER_OP_KERNEL(conv3d_grad,
                   MKLDNN,
                   ::paddle::platform::CPUPlace,