
Commit 7783d3b

Conv refine (#20644)
* add condition judgement for performance improvement test=develop
* refine code style test=develop
1 parent 57b656f commit 7783d3b

File tree

1 file changed (+32, -28 lines)

paddle/fluid/operators/conv_cudnn_op.cu

Lines changed: 32 additions & 28 deletions
@@ -540,23 +540,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
           workspace_size);
     }
 
-    std::vector<int> starts(transformed_input_channel.dims().size(), 0);
-    std::vector<int> axes(transformed_input_channel.dims().size(), 0);
+    if (!is_sys_pad) {
+      std::vector<int> starts(transformed_input_channel.dims().size(), 0);
+      std::vector<int> axes(transformed_input_channel.dims().size(), 0);
 
-    for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) {
-      starts[i] = input_pad[2 * i];
-      axes[i] = i;
-    }
+      for (size_t i = 0; i < transformed_input_channel.dims().size(); ++i) {
+        starts[i] = input_pad[2 * i];
+        axes[i] = i;
+      }
 
-    transformed_input_grad_channel.mutable_data(ctx.GetPlace());
-    if (transformed_input_channel.dims().size() == 4) {
-      Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
-          ctx, &transformed_input_grad, &transformed_input_grad_channel,
-          starts, axes);
-    } else {
-      Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
-          ctx, &transformed_input_grad, &transformed_input_grad_channel,
-          starts, axes);
+      transformed_input_grad_channel.mutable_data(ctx.GetPlace());
+      if (transformed_input_channel.dims().size() == 4) {
+        Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
+            ctx, &transformed_input_grad, &transformed_input_grad_channel,
+            starts, axes);
+      } else {
+        Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
+            ctx, &transformed_input_grad, &transformed_input_grad_channel,
+            starts, axes);
+      }
     }
 
     if (channel_last) {
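
Reading this hunk, is_sys_pad appears to flag symmetric padding that cuDNN can take directly; only when the padding is asymmetric does the backward pass run on an explicitly padded tensor, and only then does the input gradient need to be sliced back to the caller-visible shape. The new if (!is_sys_pad) guard therefore skips the Slice_2 launch in the common symmetric case. Below is a minimal, self-contained sketch of that decision, assuming paddings are laid out as (before, after) pairs the way input_pad[2 * i] is indexed; IsSymmetricPad and the sample values are illustrative, not Paddle APIs.

```cpp
// Minimal sketch (my reading of the change, not Paddle code): when every
// (before, after) padding pair is equal, no explicitly padded copy is needed
// and the slice-back step can be skipped entirely.
#include <cstddef>
#include <cstdio>
#include <vector>

// True when each padding pair is symmetric, i.e. the case cuDNN can consume
// directly without an explicitly padded input tensor.
bool IsSymmetricPad(const std::vector<int>& paddings) {
  for (std::size_t i = 0; i + 1 < paddings.size(); i += 2) {
    if (paddings[i] != paddings[i + 1]) return false;
  }
  return true;
}

int main() {
  std::vector<int> symmetric{1, 1, 2, 2};   // pad_h = (1, 1), pad_w = (2, 2)
  std::vector<int> asymmetric{1, 2, 0, 1};  // uneven, "SAME"-style padding

  // Mirrors the new guard: the slice-back branch only runs when padding was
  // asymmetric and an explicitly padded tensor was used.
  std::printf("symmetric  -> slice back? %s\n",
              IsSymmetricPad(symmetric) ? "no" : "yes");
  std::printf("asymmetric -> slice back? %s\n",
              IsSymmetricPad(asymmetric) ? "no" : "yes");
  return 0;
}
```

The point of the guard is simply that the slice-back pass and its kernel launch are skipped whenever no asymmetric padding was applied.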
@@ -982,20 +984,22 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
           workspace_size);
     }
 
-    // reverse padded input
-    std::vector<int> starts(X->dims().size(), 0);
-    std::vector<int> axes(X->dims().size(), 0);
+    if (!is_sys_pad) {
+      // reverse padded input
+      std::vector<int> starts(X->dims().size(), 0);
+      std::vector<int> axes(X->dims().size(), 0);
 
-    for (size_t i = 0; i < X->dims().size(); ++i) {
-      starts[i] = input_pad[2 * i];
-      axes[i] = i;
-    }
-    if (X->dims().size() == 4) {
-      Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
-          ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
-    } else {
-      Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
-          ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
+      for (size_t i = 0; i < X->dims().size(); ++i) {
+        starts[i] = input_pad[2 * i];
+        axes[i] = i;
+      }
+      if (X->dims().size() == 4) {
+        Slice_2<paddle::platform::CUDADeviceContext, T, 4>(
+            ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
+      } else {
+        Slice_2<paddle::platform::CUDADeviceContext, T, 5>(
+            ctx, &transformed_dX, &transformed_dX_channel, starts, axes);
+      }
     }
     if (channel_last) {
      TransToChannelLast<paddle::platform::CUDADeviceContext, T>(
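
The second hunk puts the same guard around the slice-back of transformed_dX in the double-grad kernel. To make the skipped work concrete, here is a hedged CPU reference for what a 4-D slice-back of this kind does; it is an illustrative stand-in for the GPU-side Slice_2<..., T, 4> copy, not Paddle's implementation, and HostTensor / SliceBack4D are invented names.

```cpp
// CPU reference for the work the guard skips: copy the original-shaped region
// out of the explicitly padded gradient, starting at input_pad[2 * i] in each
// dimension. Illustrative sketch only.
#include <cstddef>
#include <vector>

// NCHW tensor flattened into one vector; dims always has 4 entries here.
struct HostTensor {
  std::vector<int> dims;
  std::vector<float> data;
};

// Copies the block of `padded` that begins at `starts` and has `out->dims`
// extents into `out`. Every output element is touched exactly once.
void SliceBack4D(const HostTensor& padded, const std::vector<int>& starts,
                 HostTensor* out) {
  const std::vector<int>& pd = padded.dims;
  const std::vector<int>& od = out->dims;
  out->data.assign(static_cast<std::size_t>(od[0]) * od[1] * od[2] * od[3], 0.f);
  for (int n = 0; n < od[0]; ++n)
    for (int c = 0; c < od[1]; ++c)
      for (int h = 0; h < od[2]; ++h)
        for (int w = 0; w < od[3]; ++w) {
          // Source index inside the padded tensor, shifted by the per-dim starts.
          std::size_t src =
              ((static_cast<std::size_t>(n + starts[0]) * pd[1] + (c + starts[1])) *
                   pd[2] + (h + starts[2])) * pd[3] + (w + starts[3]);
          // Destination index inside the original-shaped output.
          std::size_t dst =
              ((static_cast<std::size_t>(n) * od[1] + c) * od[2] + h) * od[3] + w;
          out->data[dst] = padded.data[src];
        }
}
```

This copy visits every element of the gradient once, so gating it on !is_sys_pad removes a full extra pass (and a kernel launch) from the common symmetric-padding path, which appears to be the performance improvement the commit message refers to.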
