Skip to content

Commit a84a580

Browse files
authored
Add CUDA kernel for prior_box_op. (#9553)
1 parent d139f2c commit a84a580

File tree

4 files changed

+207
-68
lines changed

4 files changed

+207
-68
lines changed

paddle/fluid/operators/prior_box_op.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
7373
const framework::ExecutionContext& ctx) const override {
7474
return framework::OpKernelType(
7575
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
76-
platform::CPUPlace());
76+
ctx.device_context());
7777
}
7878
};
7979

@@ -171,6 +171,5 @@ namespace ops = paddle::operators;
171171
REGISTER_OPERATOR(prior_box, ops::PriorBoxOp, ops::PriorBoxOpMaker,
172172
paddle::framework::EmptyGradOpMaker);
173173

174-
REGISTER_OP_CPU_KERNEL(
175-
prior_box, ops::PriorBoxOpKernel<paddle::platform::CPUPlace, float>,
176-
ops::PriorBoxOpKernel<paddle::platform::CPUPlace, double>);
174+
REGISTER_OP_CPU_KERNEL(prior_box, ops::PriorBoxOpKernel<float>,
175+
ops::PriorBoxOpKernel<double>);
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/prior_box_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
/**
 * Clamps a value to the closed interval [0, 1].
 *
 * The bounds are materialized as T so that a float instantiation stays in
 * single precision instead of being promoted to double by 0. / 1. literals.
 */
template <typename T>
__device__ inline T clip(T in) {
  return min(max(in, static_cast<T>(0)), static_cast<T>(1));
}

/**
 * Generates prior boxes, one (xmin, ymin, xmax, ymax) tuple per flattened
 * index i = (h * width + w) * num_priors + p, stored at out[4 * i .. 4 * i + 3].
 *
 * Within one feature-map cell the priors are grouped by min size: for every
 * min_sizes[m] there is one box per aspect ratio, followed (only when
 * max_sizes is non-null) by one square box of side
 * sqrt(min_sizes[m] * max_sizes[m]).
 *
 * Launch layout: 1-D grid of 1-D blocks. Each thread starts at its global
 * index and advances by the total number of launched threads, so any
 * non-empty grid covers all boxes. No shared memory is used.
 *
 * @param out            Output buffer with room for
 *                       height * width * num_priors * 4 elements.
 * @param aspect_ratios  as_num aspect ratios.
 * @param height,width   Feature-map extent in cells.
 * @param im_height,im_width  Image extent used to normalize coordinates.
 * @param as_num         Number of entries in aspect_ratios.
 * @param offset         Offset added to the cell index before scaling.
 * @param step_width,step_height  Scale from cell index to image coordinates.
 * @param min_sizes      min_num minimum box sizes.
 * @param max_sizes      min_num maximum box sizes, or nullptr to emit no
 *                       max-size priors.
 * @param min_num        Number of entries in min_sizes.
 * @param is_clip        When true, every coordinate is clamped to [0, 1].
 */
template <typename T>
__global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
                            const int width, const int im_height,
                            const int im_width, const int as_num,
                            const T offset, const T step_width,
                            const T step_height, const T* min_sizes,
                            const T* max_sizes, const int min_num,
                            bool is_clip) {
  // Each min size owns as_num ratio boxes plus, optionally, one max-size box.
  const int priors_per_min = max_sizes ? as_num + 1 : as_num;
  const int num_priors = priors_per_min * min_num;
  const int box_num = height * width * num_priors;
  // T-typed divisor keeps float kernels out of double-precision arithmetic.
  const T two = static_cast<T>(2);

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
       i += blockDim.x * gridDim.x) {
    const int h = i / (num_priors * width);
    const int w = (i / num_priors) % width;
    const int p = i % num_priors;
    const int m = p / priors_per_min;  // which min (and max) size
    const int s = p % priors_per_min;  // slot inside that size group

    const T cx = (w + offset) * step_width;
    const T cy = (h + offset) * step_height;
    const T min_size = min_sizes[m];

    T bw, bh;
    if (s < as_num) {
      // Aspect-ratio prior; the square root is shared by width and height.
      const T root = sqrt(aspect_ratios[s]);
      bw = min_size * root / two;
      bh = min_size / root / two;
    } else {
      // Only reachable when max_sizes is non-null (slot == as_num).
      bw = sqrt(min_size * max_sizes[m]) / two;
      bh = bw;
    }

    T xmin = (cx - bw) / im_width;
    T ymin = (cy - bh) / im_height;
    T xmax = (cx + bw) / im_width;
    T ymax = (cy + bh) / im_height;
    if (is_clip) {
      xmin = clip<T>(xmin);
      ymin = clip<T>(ymin);
      xmax = clip<T>(xmax);
      ymax = clip<T>(ymax);
    }

    // Widen before scaling so 4 * i cannot overflow int for large outputs.
    T* box = out + static_cast<size_t>(i) * 4;
    box[0] = xmin;
    box[1] = ymin;
    box[2] = xmax;
    box[3] = ymax;
  }
}
72+
73+
/**
 * Fills out[0 .. num) by repeating the vnum values of var cyclically:
 * out[k] = var[k % vnum].
 *
 * Launch layout: 1-D grid of 1-D blocks. Each thread starts at its global
 * index and advances by the total number of launched threads. No shared
 * memory is used.
 *
 * @param out   Destination buffer with at least num elements.
 * @param var   Source values, vnum elements.
 * @param vnum  Length of var; used as the modulus, so it must be non-zero.
 * @param num   Number of elements to write.
 */
template <typename T>
__global__ void SetVariance(T* out, const T* var, const int vnum,
                            const int num) {
  // Total number of launched threads, i.e. the distance between two
  // consecutive elements handled by the same thread.
  const unsigned int stride = blockDim.x * gridDim.x;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < num) {
    const int src = idx % vnum;
    out[idx] = var[src];
    idx += stride;
  }
}
81+
82+
/**
 * CUDA implementation of the prior_box operator.
 *
 * Reads the "Input" and "Image" tensors (only dims()[2] and dims()[3] are
 * used, as height and width respectively) and the min_sizes / max_sizes /
 * aspect_ratios / variances / flip / clip / step_w / step_h / offset
 * attributes, then fills the "Boxes" and "Variances" outputs by launching
 * GenPriorBox and SetVariance on the CUDA device context's stream.
 *
 * The list attributes are stored as float. They are converted to T on the
 * host before upload so that the device tensors really hold T elements; the
 * kernels read them through data<T>(), which would otherwise be a type
 * mismatch for the double instantiation.
 */
template <typename T>
class PriorBoxOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<paddle::framework::Tensor>("Input");
    auto* image = ctx.Input<paddle::framework::Tensor>("Image");
    auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes");
    auto* vars = ctx.Output<paddle::framework::Tensor>("Variances");

    auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes");
    auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes");
    auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios");
    auto variances = ctx.Attr<std::vector<float>>("variances");
    auto flip = ctx.Attr<bool>("flip");
    // Named is_clip so it does not shadow the clip() device helper.
    auto is_clip = ctx.Attr<bool>("clip");

    // ExpandAspectRatios comes from prior_box_op.h (not visible here);
    // presumably it yields the final ratio list, including flipped ratios
    // when flip is set -- confirm against the header.
    std::vector<float> aspect_ratios;
    ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios);

    T step_w = static_cast<T>(ctx.Attr<float>("step_w"));
    T step_h = static_cast<T>(ctx.Attr<float>("step_h"));
    T offset = static_cast<T>(ctx.Attr<float>("offset"));

    auto im_width = image->dims()[3];
    auto im_height = image->dims()[2];

    auto width = input->dims()[3];
    auto height = input->dims()[2];

    // A zero step in either direction means "derive both steps from the
    // image / feature-map size ratio".
    T step_width, step_height;
    if (step_w == 0 || step_h == 0) {
      step_width = static_cast<T>(im_width) / width;
      step_height = static_cast<T>(im_height) / height;
    } else {
      step_width = step_w;
      step_height = step_h;
    }

    const int as_num = static_cast<int>(aspect_ratios.size());
    const int min_num = static_cast<int>(min_sizes.size());
    int num_priors = as_num * min_num;
    if (max_sizes.size() > 0) {
      num_priors += static_cast<int>(max_sizes.size());
    }
    const int box_num = static_cast<int>(width * height * num_priors);

    auto stream =
        ctx.template device_context<platform::CUDADeviceContext>().stream();

    boxes->mutable_data<T>(ctx.GetPlace());
    vars->mutable_data<T>(ctx.GetPlace());

    // Nothing to generate; a launch with zero blocks is an invalid
    // configuration, so bail out before computing the grid size.
    if (box_num <= 0) {
      return;
    }

    const int block = 512;
    int grid = (box_num + block - 1) / block;

    // Host-side copies converted to T. They stay alive until the end of
    // Compute, like the attribute vectors the uploads previously read from.
    const std::vector<T> ratios_t(aspect_ratios.begin(), aspect_ratios.end());
    const std::vector<T> min_sizes_t(min_sizes.begin(), min_sizes.end());
    const std::vector<T> max_sizes_t(max_sizes.begin(), max_sizes.end());
    const std::vector<T> variances_t(variances.begin(), variances.end());

    framework::Tensor ratio_tensor;
    framework::TensorFromVector(ratios_t, ctx.device_context(), &ratio_tensor);

    framework::Tensor min_tensor;
    framework::TensorFromVector(min_sizes_t, ctx.device_context(),
                                &min_tensor);

    // A null max pointer tells GenPriorBox to emit no max-size priors.
    T* max_data = nullptr;
    framework::Tensor max_tensor;
    if (max_sizes_t.size() > 0) {
      framework::TensorFromVector(max_sizes_t, ctx.device_context(),
                                  &max_tensor);
      max_data = max_tensor.data<T>();
    }

    GenPriorBox<T><<<grid, block, 0, stream>>>(
        boxes->data<T>(), ratio_tensor.data<T>(), height, width, im_height,
        im_width, as_num, offset, step_width, step_height,
        min_tensor.data<T>(), max_data, min_num, is_clip);

    // Variances holds four values per box, repeated from the attribute list.
    framework::Tensor var_tensor;
    framework::TensorFromVector(variances_t, ctx.device_context(),
                                &var_tensor);
    grid = (box_num * 4 + block - 1) / block;
    SetVariance<T><<<grid, block, 0, stream>>>(
        vars->data<T>(), var_tensor.data<T>(),
        static_cast<int>(variances_t.size()), box_num * 4);
  }
};
161+
162+
} // namespace operators
163+
} // namespace paddle
164+
165+
// Register the CUDA kernel for the "prior_box" op with float and double
// instantiations of PriorBoxOpCUDAKernel.
namespace ops = paddle::operators;
166+
REGISTER_OP_CUDA_KERNEL(prior_box, ops::PriorBoxOpCUDAKernel<float>,
167+
ops::PriorBoxOpCUDAKernel<double>);

paddle/fluid/operators/prior_box_op.h

Lines changed: 10 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct ClipFunctor {
5151
}
5252
};
5353

54-
template <typename Place, typename T>
54+
template <typename T>
5555
class PriorBoxOpKernel : public framework::OpKernel<T> {
5656
public:
5757
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -106,49 +106,24 @@ class PriorBoxOpKernel : public framework::OpKernel<T> {
106106
int idx = 0;
107107
for (size_t s = 0; s < min_sizes.size(); ++s) {
108108
auto min_size = min_sizes[s];
109-
// first prior: aspect_ratio = 1, size = min_size
110-
box_width = box_height = min_size / 2.;
111-
// xmin
112-
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
113-
// ymin
114-
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
115-
// xmax
116-
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
117-
// ymax
118-
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
119-
120-
idx++;
121-
if (max_sizes.size() > 0) {
122-
auto max_size = max_sizes[s];
123-
// second prior: aspect_ratio = 1,
124-
// size = sqrt(min_size * max_size)
125-
box_width = box_height = sqrt(min_size * max_size) / 2.;
126-
// xmin
109+
// priors with different aspect ratios
110+
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
111+
float ar = aspect_ratios[r];
112+
box_width = min_size * sqrt(ar) / 2.;
113+
box_height = min_size / sqrt(ar) / 2.;
127114
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
128-
// ymin
129115
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
130-
// xmax
131116
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
132-
// ymax
133117
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
134118
idx++;
135119
}
136-
137-
// rest of priors
138-
for (size_t r = 0; r < aspect_ratios.size(); ++r) {
139-
float ar = aspect_ratios[r];
140-
if (fabs(ar - 1.) < 1e-6) {
141-
continue;
142-
}
143-
box_width = min_size * sqrt(ar) / 2.;
144-
box_height = min_size / sqrt(ar) / 2.;
145-
// xmin
120+
if (max_sizes.size() > 0) {
121+
auto max_size = max_sizes[s];
122+
// square prior with size sqrt(minSize * maxSize)
123+
box_width = box_height = sqrt(min_size * max_size) / 2.;
146124
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
147-
// ymin
148125
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
149-
// xmax
150126
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
151-
// ymax
152127
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
153128
idx++;
154129
}

python/paddle/fluid/tests/unittests/test_prior_box_op.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ def set_data(self):
2828

2929
self.attrs = {
3030
'min_sizes': self.min_sizes,
31-
'max_sizes': self.max_sizes,
3231
'aspect_ratios': self.aspect_ratios,
3332
'variances': self.variances,
3433
'flip': self.flip,
@@ -37,25 +36,28 @@ def set_data(self):
3736
'step_h': self.step_h,
3837
'offset': self.offset
3938
}
39+
if len(self.max_sizes) > 0:
40+
self.attrs['max_sizes'] = self.max_sizes
4041

4142
self.outputs = {'Boxes': self.out_boxes, 'Variances': self.out_var}
4243

4344
def test_check_output(self):
4445
self.check_output()
4546

46-
def test_check_grad(self):
47-
return
48-
4947
def setUp(self):
5048
self.op_type = "prior_box"
5149
self.set_data()
5250

51+
def set_max_sizes(self):
52+
max_sizes = [5, 10]
53+
self.max_sizes = np.array(max_sizes).astype('float32').tolist()
54+
5355
def init_test_params(self):
54-
self.layer_w = 4
55-
self.layer_h = 4
56+
self.layer_w = 32
57+
self.layer_h = 32
5658

57-
self.image_w = 20
58-
self.image_h = 20
59+
self.image_w = 40
60+
self.image_h = 40
5961

6062
self.step_w = float(self.image_w) / float(self.layer_w)
6163
self.step_h = float(self.image_h) / float(self.layer_h)
@@ -66,8 +68,7 @@ def init_test_params(self):
6668

6769
self.min_sizes = [2, 4]
6870
self.min_sizes = np.array(self.min_sizes).astype('float32').tolist()
69-
self.max_sizes = [5, 10]
70-
self.max_sizes = np.array(self.max_sizes).astype('float32').tolist()
71+
self.set_max_sizes()
7172
self.aspect_ratios = [2.0, 3.0]
7273
self.flip = True
7374
self.real_aspect_ratios = [1, 2.0, 1.0 / 2.0, 3.0, 1.0 / 3.0]
@@ -79,7 +80,7 @@ def init_test_params(self):
7980
self.clip = True
8081

8182
self.num_priors = len(self.real_aspect_ratios) * len(self.min_sizes)
82-
if len(self.max_sizes) > 1:
83+
if len(self.max_sizes) > 0:
8384
self.num_priors += len(self.max_sizes)
8485
self.offset = 0.5
8586

@@ -105,35 +106,27 @@ def init_test_output(self):
105106
idx = 0
106107
for s in range(len(self.min_sizes)):
107108
min_size = self.min_sizes[s]
108-
c_w = c_h = min_size / 2.
109-
out_boxes[h, w, idx, :] = [
110-
(c_x - c_w) / self.image_w, (c_y - c_h) / self.image_h,
111-
(c_x + c_w) / self.image_w, (c_y + c_h) / self.image_h
112-
]
113-
idx += 1
114-
115-
if len(self.max_sizes) > 0:
116-
max_size = self.max_sizes[s]
117-
# second prior: aspect_ratio = 1,
118-
c_w = c_h = math.sqrt(min_size * max_size) / 2
109+
# priors with different aspect ratios
110+
for r in range(len(self.real_aspect_ratios)):
111+
ar = self.real_aspect_ratios[r]
112+
c_w = min_size * math.sqrt(ar) / 2
113+
c_h = (min_size / math.sqrt(ar)) / 2
119114
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
120115
(c_y - c_h) / self.image_h,
121116
(c_x + c_w) / self.image_w,
122117
(c_y + c_h) / self.image_h]
123118
idx += 1
124119

125-
# rest of priors
126-
for r in range(len(self.real_aspect_ratios)):
127-
ar = self.real_aspect_ratios[r]
128-
if math.fabs(ar - 1.) < 1e-6:
129-
continue
130-
c_w = min_size * math.sqrt(ar) / 2
131-
c_h = (min_size / math.sqrt(ar)) / 2
120+
if len(self.max_sizes) > 0:
121+
max_size = self.max_sizes[s]
122+
# square prior: aspect_ratio = 1, size = sqrt(min_size * max_size)
123+
c_w = c_h = math.sqrt(min_size * max_size) / 2
132124
out_boxes[h, w, idx, :] = [(c_x - c_w) / self.image_w,
133125
(c_y - c_h) / self.image_h,
134126
(c_x + c_w) / self.image_w,
135127
(c_y + c_h) / self.image_h]
136128
idx += 1
129+
137130
# clip the prior's coordinates such that they are within [0, 1]
138131
if self.clip:
139132
out_boxes = np.clip(out_boxes, 0.0, 1.0)
@@ -144,5 +137,10 @@ def init_test_output(self):
144137
self.out_var = out_var.astype('float32')
145138

146139

140+
class TestPriorBoxOpWithMaxSize(TestPriorBoxOp):
141+
def set_max_sizes(self):
142+
self.max_sizes = []
143+
144+
147145
if __name__ == '__main__':
148146
unittest.main()

0 commit comments

Comments
 (0)