Skip to content

Commit f8ecc3d

Browse files
lidanqing-intel authored and luotao1 committed
refactor the function ConvFwdPrimitiveDesc (#17897)
* refactor the function ConvFwdPrimitiveDesc test=develop * change according to review test=develop * use pointer way without boost::optional test=develop * pass vector to function by reference instead of raw vector test=develop * change pointer to shared_ptr test=develop
1 parent 8462e2b commit f8ecc3d

File tree

1 file changed

+27
-58
lines changed

1 file changed

+27
-58
lines changed

paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc

Lines changed: 27 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -383,14 +383,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
383383
const std::string key_conv_pd = key + "@conv_pd";
384384

385385
bool need_s8_to_u8 = false;
386-
std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
387-
std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
388-
std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
389-
std::shared_ptr<mkldnn::memory> dst_memory_p = nullptr;
386+
std::shared_ptr<mkldnn::convolution_forward> conv_p;
387+
std::shared_ptr<mkldnn::memory> src_memory_p;
388+
std::shared_ptr<mkldnn::memory> user_src_memory_p;
389+
std::shared_ptr<mkldnn::memory> dst_memory_p;
390390
std::vector<primitive> pipeline;
391-
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
392-
nullptr;
393-
std::shared_ptr<platform::ConvMKLDNNHandler> handler = nullptr;
391+
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
392+
std::shared_ptr<platform::ConvMKLDNNHandler> handler;
394393

395394
auto prim_key = key + "@conv_p";
396395
auto dst_key = key + "@dst_mem_p";
@@ -460,24 +459,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
460459
// TODO(lidanqing): We use relu post-op instead of brelu post-op cause
461460
// mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
462461
// v0.20 is enabled
462+
std::shared_ptr<memory::desc> bias_md_p;
463463
if (bias) {
464464
bias_tz = paddle::framework::vectorize2int(bias->dims());
465-
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
466-
memory::format::x);
467-
468-
conv_pd = ConvFwdPrimitiveDesc(
469-
src_md, weights_md, bias_md, dst_md, strides, paddings,
470-
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
471-
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
472-
output_shift_scale, sum_scale, is_test);
473-
474-
} else {
475-
conv_pd = ConvFwdPrimitiveDesc(
476-
src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
477-
fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn,
478-
false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale,
479-
sum_scale, is_test);
465+
bias_md_p = std::make_shared<memory::desc>(platform::MKLDNNMemDesc(
466+
bias_tz, memory::data_type::s32, memory::format::x));
480467
}
468+
conv_pd = ConvFwdPrimitiveDesc(
469+
src_md, weights_md, bias_md_p, dst_md, strides, paddings,
470+
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
471+
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
472+
output_shift_scale, sum_scale, is_test);
481473
// Save conv_pd/src_memory/weights_memory for backward pass
482474
dev_ctx.SetBlob(key_conv_pd, conv_pd);
483475
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
@@ -649,7 +641,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
649641
private:
650642
mkldnn::primitive_attr CreatePostOps(
651643
bool fuse_relu, bool fuse_residual_conn,
652-
const std::vector<float> output_shift_scale, float sum_scale,
644+
const std::vector<float>& output_shift_scale, float sum_scale,
653645
bool fuse_brelu, float fuse_brelu_threshold) const {
654646
mkldnn::primitive_attr conv_attr;
655647
mkldnn::post_ops post_operations;
@@ -679,52 +671,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
679671

680672
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
681673
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
674+
const std::shared_ptr<memory::desc> bias_md_p,
682675
const memory::desc& dst, const std::vector<int>& strides,
683676
const std::vector<int>& paddings,
684677
const mkldnn::engine& engine, const bool fuse_relu,
685678
const bool fuse_residual_conn, const bool fuse_brelu,
686679
const float fuse_brelu_threshold,
687-
const std::vector<float> output_shift_scale,
680+
const std::vector<float>& output_shift_scale,
688681
const float sum_scale, bool is_test) const {
689682
memory::dims stride_dims = {strides[0], strides[1]};
690683
memory::dims padding_dims = {paddings[0], paddings[1]};
691684

692685
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
693686
: mkldnn::prop_kind::forward_training;
694-
695-
auto conv_desc = mkldnn::convolution_forward::desc(
696-
propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
697-
padding_dims, padding_dims, mkldnn::padding_kind::zero);
698-
mkldnn::primitive_attr conv_attr =
699-
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
700-
sum_scale, fuse_brelu, fuse_brelu_threshold);
701-
702-
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
703-
conv_desc, conv_attr, engine);
704-
705-
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
706-
p_conv_pd);
707-
}
708-
709-
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
710-
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
711-
const memory::desc& bias, const memory::desc& dst,
712-
const std::vector<int>& strides,
713-
const std::vector<int>& paddings,
714-
const mkldnn::engine& engine, const bool fuse_relu,
715-
const bool fuse_residual_conn, const bool fuse_brelu,
716-
const float fuse_brelu_threshold,
717-
const std::vector<float> output_shift_scale,
718-
const float sum_scale, bool is_test) const {
719-
memory::dims stride_dims = {strides[0], strides[1]};
720-
memory::dims padding_dims = {paddings[0], paddings[1]};
721-
722-
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
723-
: mkldnn::prop_kind::forward_training;
724-
725-
auto conv_desc = mkldnn::convolution_forward::desc(
726-
propagation, mkldnn::convolution_direct, src, weights, bias, dst,
727-
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
687+
auto conv_desc =
688+
(bias_md_p != nullptr)
689+
? mkldnn::convolution_forward::desc(
690+
propagation, mkldnn::convolution_direct, src, weights,
691+
(*bias_md_p), dst, stride_dims, padding_dims, padding_dims,
692+
mkldnn::padding_kind::zero)
693+
: mkldnn::convolution_forward::desc(
694+
propagation, mkldnn::convolution_direct, src, weights, dst,
695+
stride_dims, padding_dims, padding_dims,
696+
mkldnn::padding_kind::zero);
728697

729698
mkldnn::primitive_attr conv_attr =
730699
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,

0 commit comments

Comments
 (0)