@@ -383,14 +383,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
383
383
const std::string key_conv_pd = key + " @conv_pd" ;
384
384
385
385
bool need_s8_to_u8 = false ;
386
- std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr ;
387
- std::shared_ptr<mkldnn::memory> src_memory_p = nullptr ;
388
- std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr ;
389
- std::shared_ptr<mkldnn::memory> dst_memory_p = nullptr ;
386
+ std::shared_ptr<mkldnn::convolution_forward> conv_p;
387
+ std::shared_ptr<mkldnn::memory> src_memory_p;
388
+ std::shared_ptr<mkldnn::memory> user_src_memory_p;
389
+ std::shared_ptr<mkldnn::memory> dst_memory_p;
390
390
std::vector<primitive> pipeline;
391
- std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
392
- nullptr ;
393
- std::shared_ptr<platform::ConvMKLDNNHandler> handler = nullptr ;
391
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
392
+ std::shared_ptr<platform::ConvMKLDNNHandler> handler;
394
393
395
394
auto prim_key = key + " @conv_p" ;
396
395
auto dst_key = key + " @dst_mem_p" ;
@@ -460,24 +459,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
460
459
// TODO(lidanqing): We use relu post-op instead of brelu post-op because
461
460
// mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
462
461
// v0.20 is enabled
462
+ std::shared_ptr<memory::desc> bias_md_p;
463
463
if (bias) {
464
464
bias_tz = paddle::framework::vectorize2int (bias->dims ());
465
- auto bias_md = platform::MKLDNNMemDesc (bias_tz, memory::data_type::s32,
466
- memory::format::x);
467
-
468
- conv_pd = ConvFwdPrimitiveDesc (
469
- src_md, weights_md, bias_md, dst_md, strides, paddings,
470
- mkldnn_engine, fuse_relu || fuse_brelu /* fuse_relu*/ ,
471
- fuse_residual_conn, false /* fuse_brelu*/ , fuse_brelu_threshold,
472
- output_shift_scale, sum_scale, is_test);
473
-
474
- } else {
475
- conv_pd = ConvFwdPrimitiveDesc (
476
- src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
477
- fuse_relu || fuse_brelu /* fuse_relu*/ , fuse_residual_conn,
478
- false /* fuse_brelu*/ , fuse_brelu_threshold, output_shift_scale,
479
- sum_scale, is_test);
465
+ bias_md_p = std::make_shared<memory::desc>(platform::MKLDNNMemDesc (
466
+ bias_tz, memory::data_type::s32, memory::format::x));
480
467
}
468
+ conv_pd = ConvFwdPrimitiveDesc (
469
+ src_md, weights_md, bias_md_p, dst_md, strides, paddings,
470
+ mkldnn_engine, fuse_relu || fuse_brelu /* fuse_relu*/ ,
471
+ fuse_residual_conn, false /* fuse_brelu*/ , fuse_brelu_threshold,
472
+ output_shift_scale, sum_scale, is_test);
481
473
// Save conv_pd/src_memory/weights_memory for backward pass
482
474
dev_ctx.SetBlob (key_conv_pd, conv_pd);
483
475
handler.reset (new platform::ConvMKLDNNHandler (conv_pd, dev_ctx,
@@ -649,7 +641,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
649
641
private:
650
642
mkldnn::primitive_attr CreatePostOps (
651
643
bool fuse_relu, bool fuse_residual_conn,
652
- const std::vector<float > output_shift_scale, float sum_scale,
644
+ const std::vector<float >& output_shift_scale, float sum_scale,
653
645
bool fuse_brelu, float fuse_brelu_threshold) const {
654
646
mkldnn::primitive_attr conv_attr;
655
647
mkldnn::post_ops post_operations;
@@ -679,52 +671,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
679
671
680
672
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
681
673
ConvFwdPrimitiveDesc (const memory::desc& src, const memory::desc& weights,
674
+ const std::shared_ptr<memory::desc> bias_md_p,
682
675
const memory::desc& dst, const std::vector<int >& strides,
683
676
const std::vector<int >& paddings,
684
677
const mkldnn::engine& engine, const bool fuse_relu,
685
678
const bool fuse_residual_conn, const bool fuse_brelu,
686
679
const float fuse_brelu_threshold,
687
- const std::vector<float > output_shift_scale,
680
+ const std::vector<float >& output_shift_scale,
688
681
const float sum_scale, bool is_test) const {
689
682
memory::dims stride_dims = {strides[0 ], strides[1 ]};
690
683
memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
691
684
692
685
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
693
686
: mkldnn::prop_kind::forward_training;
694
-
695
- auto conv_desc = mkldnn::convolution_forward::desc (
696
- propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
697
- padding_dims, padding_dims, mkldnn::padding_kind::zero);
698
- mkldnn::primitive_attr conv_attr =
699
- CreatePostOps (fuse_relu, fuse_residual_conn, output_shift_scale,
700
- sum_scale, fuse_brelu, fuse_brelu_threshold);
701
-
702
- auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc (
703
- conv_desc, conv_attr, engine);
704
-
705
- return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
706
- p_conv_pd);
707
- }
708
-
709
- std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
710
- ConvFwdPrimitiveDesc (const memory::desc& src, const memory::desc& weights,
711
- const memory::desc& bias, const memory::desc& dst,
712
- const std::vector<int >& strides,
713
- const std::vector<int >& paddings,
714
- const mkldnn::engine& engine, const bool fuse_relu,
715
- const bool fuse_residual_conn, const bool fuse_brelu,
716
- const float fuse_brelu_threshold,
717
- const std::vector<float > output_shift_scale,
718
- const float sum_scale, bool is_test) const {
719
- memory::dims stride_dims = {strides[0 ], strides[1 ]};
720
- memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
721
-
722
- auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
723
- : mkldnn::prop_kind::forward_training;
724
-
725
- auto conv_desc = mkldnn::convolution_forward::desc (
726
- propagation, mkldnn::convolution_direct, src, weights, bias, dst,
727
- stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
687
+ auto conv_desc =
688
+ (bias_md_p != nullptr )
689
+ ? mkldnn::convolution_forward::desc (
690
+ propagation, mkldnn::convolution_direct, src, weights,
691
+ (*bias_md_p), dst, stride_dims, padding_dims, padding_dims,
692
+ mkldnn::padding_kind::zero)
693
+ : mkldnn::convolution_forward::desc (
694
+ propagation, mkldnn::convolution_direct, src, weights, dst,
695
+ stride_dims, padding_dims, padding_dims,
696
+ mkldnn::padding_kind::zero);
728
697
729
698
mkldnn::primitive_attr conv_attr =
730
699
CreatePostOps (fuse_relu, fuse_residual_conn, output_shift_scale,
0 commit comments