Skip to content

[static op generation] pool2d, pool3d #54070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a9ba11a
[phi] autogen code pool2d
gouzil May 23, 2023
01f4c46
[phi] move poolop GetExpectedKernelType and GetKernelTypeForVar
gouzil May 23, 2023
7f7fabc
[phi] fix move GetExpectedKernelType
gouzil May 23, 2023
fc3dbc7
[phi] autogen code pool3d
gouzil May 23, 2023
dbda093
[fluid] fix get_expected_kernel_func
gouzil May 24, 2023
1399e2a
[fluid] fix get_expected_kernel_func
gouzil May 24, 2023
ddf15ba
[phi][yml] fix args
gouzil May 24, 2023
e81a123
[fluid] rm pool_op cmake
gouzil May 24, 2023
94eabd3
[phi] fix GetKernelTypeForVar
gouzil May 24, 2023
a0399b1
[phi] fix args; reduction legacy
gouzil May 25, 2023
2099624
[phi] fix pool2d_double_grad args
gouzil May 25, 2023
010194d
[phi] fix pool3d args; reduction pool3d legacy; fix test
gouzil May 25, 2023
6f615f0
clean
gouzil May 25, 2023
83ffe1d
[phi] fix yaml add attrs ; [test] fix cmake
gouzil May 25, 2023
19444d3
[phi] fix pool3d args
gouzil May 25, 2023
b8ab81c
[phi] fix pool2d_double_grad args
gouzil May 25, 2023
0a743bf
Merge branch 'develop' of github.com:gouzil/Paddle into autogen_code_…
gouzil May 28, 2023
ec9cfe5
[phi] op_compat add keep_signature; RollBACK pool_sig
gouzil May 29, 2023
d263253
[phi] RollBACK output
gouzil May 29, 2023
23a8700
[phi] try fix multiple definition
gouzil May 29, 2023
791de7e
[phi] fix keep_signature; RollBACK op_compat
gouzil May 29, 2023
39f9ec1
[phi] fix manual_signature
gouzil May 29, 2023
8c4810f
Merge branch 'develop' of github.com:gouzil/Paddle into autogen_code_…
gouzil May 30, 2023
b1d9b1b
Merge branches 'autogen_code_pool2d' and 'develop' of github.com:gouz…
gouzil May 30, 2023
b51b2ac
[phi]clean get_expected_kernel_func;[phi] rm compat;
gouzil May 31, 2023
e0baf77
Merge branch 'develop' of github.com:gouzil/Paddle into autogen_code_…
gouzil Jun 1, 2023
d4ce114
[phi] Fix Misdeletion
gouzil Jun 1, 2023
36f722f
Merge branch 'develop' of github.com:gouzil/Paddle into autogen_code_…
gouzil Jun 2, 2023
8b81c16
add use_cudnn
gouzil Jun 2, 2023
fd262b2
Merge branch 'develop' of github.com:gouzil/Paddle into autogen_code_…
gouzil Jun 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ static bool ReduceOpHasOptimizedOneDNNKernel(
return true;
}

// only poolop
bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
if (ctx.Attr<bool>("adaptive") == false) return true;
// oneDNN is supporting only unchangable in size pool window
auto src_tz = phi::vectorize(ctx.Input<phi::DenseTensor>("X")->dims());
if (!ctx.HasAttr("ksize")) {
return false;
}
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
// Fast but not exhustive check
return ((src_tz[src_tz.size() - 1] % ksize[1] == 0) &&
(src_tz[src_tz.size() - 2] % ksize[0] == 0));
}

phi::KernelKey GetReduceExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
Expand Down Expand Up @@ -117,6 +131,31 @@ phi::KernelKey GetAssignExpectedKernelType(
ctx.device_context().GetPlace());
}

phi::KernelKey GetPoolExpectedKernelType(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加

phi::KernelKey GetPoolDoubleGradExpectedKernelType(
    const framework::ExecutionContext& ctx,
    const framework::OperatorWithKernel* op_ptr) {
  auto data_type = op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "grad_x@GRAD");

  // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
  op_ptr->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
  // NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN

  return phi::KernelKey(data_type, ctx.GetPlace());
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
auto data_type = op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X");

// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
op_ptr->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu) END: Above codes originally enclosed by PADDLE_WITH_MKLDNN

return phi::KernelKey(data_type, ctx.GetPlace());
}

phi::KernelKey GetPoolGradExpectedKernelType(
const framework::ExecutionContext& ctx,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除这个,逻辑跟GetPoolExpectedKernelType重复

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

const framework::OperatorWithKernel* op_ptr) {
auto input_data_type =
op_ptr->OperatorWithKernel::IndicateVarDataType(ctx, "X");

// NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
op_ptr->SetDnnFallback(!CanMKLDNNSupportPool(ctx));
// NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN

return phi::KernelKey(input_data_type, ctx.GetPlace());
}

phi::KernelKey GetSgdExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr) {
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/operators/generator/get_expected_kernel_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ phi::KernelKey GetAssignExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetPoolExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetPoolGradExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);

phi::KernelKey GetSgdExpectedKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel* op_ptr);
Expand Down
Loading