Skip to content

Commit 7f72de7

Browse files
committed
[PHI] Enable depthwise convolution cudnn
1 parent 9bf45f1 commit 7f72de7

File tree

6 files changed

+241
-55
lines changed

6 files changed

+241
-55
lines changed

paddle/phi/kernels/gpu/depthwise_conv.h

+70
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,76 @@ namespace cub = hipcub;
2929
#include "paddle/phi/backends/gpu/gpu_device_function.h"
3030
#include "paddle/phi/backends/gpu/gpu_primitives.h"
3131
#include "paddle/phi/kernels/funcs/math_function.h"
32+
#include "paddle/phi/kernels/gpudnn/conv_gpudnn.h"
33+
34+
namespace phi {
35+
// To determine use cudnn or not.
36+
struct DWConvParams {
37+
bool has_fuse_relu_;
38+
std::string data_format_;
39+
std::vector<int> strides_;
40+
std::vector<int> dilations_;
41+
42+
DWConvParams(const bool has_fuse_relu,
43+
const std::string& data_format,
44+
const std::vector<int>& strides,
45+
const std::vector<int>& dilations)
46+
: has_fuse_relu_(has_fuse_relu),
47+
data_format_(data_format),
48+
strides_(strides),
49+
dilations_(dilations) {}
50+
51+
bool is_strided() const {
52+
for (const auto& stride : strides_) {
53+
if (stride != 1) return true;
54+
}
55+
return false;
56+
}
57+
58+
bool is_dilated() const {
59+
for (const auto& dilation : dilations_) {
60+
if (dilation != 1) return true;
61+
}
62+
return false;
63+
}
64+
65+
// Use cudnn for NHWC and NCHW FP16.
66+
bool UseCudnnDepthwise(const DenseTensor& input,
67+
const DenseTensor& filter) const {
68+
// No fuse supported yet.
69+
if (has_fuse_relu_) {
70+
return false;
71+
}
72+
// Cudnn enable
73+
if (!dynload::HasCUDNN()) {
74+
return false;
75+
}
76+
// Use cudnn depthwise conv for channel last format.
77+
if (data_format_ == "NHWC") {
78+
return true;
79+
}
80+
// Only support FP16.
81+
if (input.type() != phi::DataType::FLOAT16 &&
82+
filter.type() != phi::DataType::FLOAT16) {
83+
return false;
84+
}
85+
// Only support depthwise 2D.
86+
if (input.dims().size() != 4) {
87+
return false;
88+
}
89+
// No dilation and stride.
90+
if (is_dilated() || is_strided()) {
91+
return false;
92+
}
93+
// Format here is NCHW, channel greater than 32, need benchmarks.
94+
if (input.dims()[1] < 32) {
95+
return false;
96+
}
97+
return true;
98+
}
99+
};
100+
101+
} // namespace phi
32102

33103
namespace paddle {
34104
namespace operators {

paddle/phi/kernels/gpu/depthwise_conv_grad_kernel.cu

+31
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,37 @@ void DepthwiseConvGradKernel(const Context& dev_ctx,
5252
std::vector<int> paddings = paddings_t;
5353
std::vector<int> dilations = dilations_t;
5454

55+
// Enable if cudnn above 8.2, hip already has cudnn kernel.
56+
#if defined(CUDNN_VERSION) && CUDNN_VERSION_MIN(8, 2, 0) && \
57+
!defined(PADDLE_WITH_HIP)
58+
DWConvParams params(has_fuse_relu, data_format, strides, dilations);
59+
if (params.UseCudnnDepthwise(input, filter)) {
60+
// Keep same with original kernel.
61+
phi::funcs::SetConstant<Context, T> set_zero;
62+
if (input_grad) {
63+
dev_ctx.template Alloc<T>(input_grad);
64+
set_zero(dev_ctx, input_grad, static_cast<T>(0));
65+
}
66+
if (filter_grad) {
67+
dev_ctx.template Alloc<T>(filter_grad);
68+
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
69+
}
70+
phi::DepthwiseConvCudnnGradKernel<T>(dev_ctx,
71+
input,
72+
filter,
73+
*output_grad,
74+
strides_t,
75+
paddings_t,
76+
padding_algorithm,
77+
groups,
78+
dilations_t,
79+
data_format,
80+
input_grad,
81+
filter_grad);
82+
return;
83+
}
84+
#endif
85+
5586
// update padding and dilation
5687
auto in_dims = input.dims();
5788
auto filter_dims = filter.dims();

paddle/phi/kernels/gpu/depthwise_conv_kernel.cu

+19
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,25 @@ void DepthwiseConvKernel(const Context& dev_ctx,
7272
input.dims()[1]));
7373
}
7474

75+
// Enable if cudnn above 8.2, hip already has cudnn kernel.
76+
#if defined(CUDNN_VERSION) && CUDNN_VERSION_MIN(8, 2, 0) && \
77+
!defined(PADDLE_WITH_HIP)
78+
DWConvParams params(has_fuse_relu, data_format, strides, dilations);
79+
if (params.UseCudnnDepthwise(input, filter)) {
80+
phi::DepthwiseConvCudnnKernel<T>(dev_ctx,
81+
input,
82+
filter,
83+
strides_t,
84+
paddings_t,
85+
padding_algorithm,
86+
groups,
87+
dilations_t,
88+
data_format,
89+
out);
90+
return;
91+
}
92+
#endif
93+
7594
// update padding and dilation
7695
auto in_dims = input.dims();
7796
auto filter_dims = filter.dims();
+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/kernels/conv_kernel.h"
18+
19+
#include "paddle/phi/backends/context_pool.h"
20+
#include "paddle/phi/backends/gpu/gpu_context.h"
21+
#include "paddle/phi/core/dense_tensor.h"
22+
#include "paddle/phi/core/kernel_registry.h"
23+
24+
#ifdef PADDLE_WITH_HIP
25+
#include "paddle/phi/kernels/gpudnn/conv_miopen_helper.h"
26+
#else
27+
#include "paddle/phi/kernels/gpudnn/conv_cudnn_v7.h"
28+
#endif
29+
30+
#include "paddle/phi/common/bfloat16.h"
31+
#include "paddle/phi/common/float16.h"
32+
33+
#ifdef PADDLE_WITH_CUDNN_FRONTEND
34+
// clang-format off
35+
#include "paddle/phi/backends/dynload/cudnn_frontend.h"
36+
#include "paddle/phi/kernels/autotune/cache.h"
37+
#include "paddle/phi/kernels/gpudnn/conv_cudnn_frontend.h"
38+
// clang-format on
39+
#endif
40+
41+
namespace phi {
42+
43+
template <typename T, typename Context>
44+
void ConvCudnnKernel(const Context& ctx,
45+
const DenseTensor& input,
46+
const DenseTensor& filter,
47+
const std::vector<int>& strides,
48+
const std::vector<int>& paddings_t,
49+
const std::string& padding_algorithm,
50+
const std::vector<int>& dilations_t,
51+
int groups,
52+
const std::string& data_format,
53+
DenseTensor* output);
54+
55+
template <typename T, typename Context>
56+
void DepthwiseConvCudnnKernel(const Context& dev_ctx,
57+
const DenseTensor& input,
58+
const DenseTensor& filter,
59+
const std::vector<int>& strides,
60+
const std::vector<int>& paddings,
61+
const std::string& padding_algorithm,
62+
int groups,
63+
const std::vector<int>& dilations,
64+
const std::string& data_format,
65+
DenseTensor* out) {
66+
ConvCudnnKernel<T>(dev_ctx,
67+
input,
68+
filter,
69+
strides,
70+
paddings,
71+
padding_algorithm,
72+
dilations,
73+
groups,
74+
data_format,
75+
out);
76+
}
77+
78+
template <typename T, typename Context>
79+
void ConvCudnnGradKernel(const Context& ctx,
80+
const DenseTensor& input,
81+
const DenseTensor& filter,
82+
const DenseTensor& output_grad,
83+
const std::vector<int>& strides_t,
84+
const std::vector<int>& paddings_t,
85+
const std::string& padding_algorithm,
86+
const std::vector<int>& dilations_t,
87+
int groups,
88+
const std::string& data_format,
89+
DenseTensor* input_grad,
90+
DenseTensor* filter_grad);
91+
92+
template <typename T, typename Context>
93+
void DepthwiseConvCudnnGradKernel(const Context& dev_ctx,
94+
const DenseTensor& input,
95+
const DenseTensor& filter,
96+
const DenseTensor& out_grad,
97+
const std::vector<int>& strides,
98+
const std::vector<int>& paddings,
99+
const std::string& padding_algorithm,
100+
int groups,
101+
const std::vector<int>& dilations,
102+
const std::string& data_format,
103+
DenseTensor* input_grad,
104+
DenseTensor* filter_grad) {
105+
ConvCudnnGradKernel<T>(dev_ctx,
106+
input,
107+
filter,
108+
out_grad,
109+
strides,
110+
paddings,
111+
padding_algorithm,
112+
dilations,
113+
groups,
114+
data_format,
115+
input_grad,
116+
filter_grad);
117+
}
118+
119+
} // namespace phi

paddle/phi/kernels/gpudnn/conv_grad_kernel.cu

+1-31
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/conv_grad_kernel.h"
16+
#include "paddle/phi/kernels/gpudnn/conv_gpudnn.h"
1617

1718
#include "glog/logging.h"
1819

@@ -759,37 +760,6 @@ void Conv3DCudnnGradKernel(const Context& dev_ctx,
759760
filter_grad);
760761
}
761762

762-
template <typename T, typename Context>
763-
void DepthwiseConvCudnnGradKernel(const Context& dev_ctx,
764-
const DenseTensor& input,
765-
const DenseTensor& filter,
766-
const DenseTensor& out_grad,
767-
const std::vector<int>& strides,
768-
const std::vector<int>& paddings,
769-
const std::string& padding_algorithm,
770-
int groups,
771-
const std::vector<int>& dilations,
772-
const std::string& data_format,
773-
bool use_addto,
774-
int workspace_size_MB,
775-
bool exhaustive_search,
776-
bool fuse_relu,
777-
DenseTensor* input_grad,
778-
DenseTensor* filter_grad) {
779-
ConvCudnnGradKernel<T>(dev_ctx,
780-
input,
781-
filter,
782-
out_grad,
783-
strides,
784-
paddings,
785-
padding_algorithm,
786-
dilations,
787-
groups,
788-
data_format,
789-
input_grad,
790-
filter_grad);
791-
}
792-
793763
template <typename T, typename Context>
794764
void ConvCudnnGradGradKernel(
795765
const Context& ctx,

paddle/phi/kernels/gpudnn/conv_kernel.cu

+1-24
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/conv_kernel.h"
16+
#include "paddle/phi/kernels/gpudnn/conv_gpudnn.h"
1617

1718
#include "glog/logging.h"
1819

@@ -557,30 +558,6 @@ void Conv3DCudnnKernel(const Context& dev_ctx,
557558
data_format,
558559
out);
559560
}
560-
561-
template <typename T, typename Context>
562-
void DepthwiseConvCudnnKernel(const Context& dev_ctx,
563-
const DenseTensor& input,
564-
const DenseTensor& filter,
565-
const std::vector<int>& strides,
566-
const std::vector<int>& paddings,
567-
const std::string& padding_algorithm,
568-
int groups,
569-
const std::vector<int>& dilations,
570-
const std::string& data_format,
571-
DenseTensor* out) {
572-
ConvCudnnKernel<T>(dev_ctx,
573-
input,
574-
filter,
575-
strides,
576-
paddings,
577-
padding_algorithm,
578-
dilations,
579-
groups,
580-
data_format,
581-
out);
582-
}
583-
584561
} // namespace phi
585562

586563
#ifdef PADDLE_WITH_HIP

0 commit comments

Comments (0)