Skip to content

Commit 893789a

Browse files
authored
Merge pull request #16050 from jerrywgz/add_box_decoder_and_assign
Add box decoder and assign
2 parents 045e591 + 072eca3 commit 893789a

File tree

7 files changed

+580
-0
lines changed

7 files changed

+580
-0
lines changed

paddle/fluid/API.spec

+1
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ paddle.fluid.layers.polygon_box_transform (ArgSpec(args=['input', 'name'], varar
329329
paddle.fluid.layers.yolov3_loss (ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'anchor_mask', 'class_num', 'ignore_thresh', 'downsample_ratio', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '991e934c3e09abf0edec7c9c978b4691'))
330330
paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '397e9e02b451d99c56e20f268fa03f2e'))
331331
paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
332+
paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '005a5ae47d6c8fff721931d69d072b9f'))
332333
paddle.fluid.layers.accuracy (ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)), ('document', '9808534c12c5e739a10f73ebb0b4eafd'))
333334
paddle.fluid.layers.auc (ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)), ('document', 'e0e95334fce92d16c2d9db6e7caffc47'))
334335
paddle.fluid.layers.exponential_decay (ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)), ('document', '98a5050bee8522fcea81aa795adaba51'))

paddle/fluid/operators/detection/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
3333
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
3434
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
3535
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
36+
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
3637

3738
if(WITH_GPU)
3839
detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h"
13+
14+
namespace paddle {
15+
namespace operators {
16+
17+
using LoDTensor = framework::LoDTensor;
18+
19+
class BoxDecoderAndAssignOp : public framework::OperatorWithKernel {
20+
public:
21+
using framework::OperatorWithKernel::OperatorWithKernel;
22+
23+
protected:
24+
void InferShape(framework::InferShapeContext *ctx) const override {
25+
PADDLE_ENFORCE(
26+
ctx->HasInput("PriorBox"),
27+
"Input(PriorBox) of BoxDecoderAndAssignOp should not be null.");
28+
PADDLE_ENFORCE(
29+
ctx->HasInput("PriorBoxVar"),
30+
"Input(PriorBoxVar) of BoxDecoderAndAssignOp should not be null.");
31+
PADDLE_ENFORCE(
32+
ctx->HasInput("TargetBox"),
33+
"Input(TargetBox) of BoxDecoderAndAssignOp should not be null.");
34+
PADDLE_ENFORCE(
35+
ctx->HasInput("BoxScore"),
36+
"Input(BoxScore) of BoxDecoderAndAssignOp should not be null.");
37+
PADDLE_ENFORCE(
38+
ctx->HasOutput("DecodeBox"),
39+
"Output(DecodeBox) of BoxDecoderAndAssignOp should not be null.");
40+
PADDLE_ENFORCE(
41+
ctx->HasOutput("OutputAssignBox"),
42+
"Output(OutputAssignBox) of BoxDecoderAndAssignOp should not be null.");
43+
44+
auto prior_box_dims = ctx->GetInputDim("PriorBox");
45+
auto prior_box_var_dims = ctx->GetInputDim("PriorBoxVar");
46+
auto target_box_dims = ctx->GetInputDim("TargetBox");
47+
auto box_score_dims = ctx->GetInputDim("BoxScore");
48+
49+
PADDLE_ENFORCE_EQ(prior_box_dims.size(), 2,
50+
"The rank of Input of PriorBox must be 2");
51+
PADDLE_ENFORCE_EQ(prior_box_dims[1], 4, "The shape of PriorBox is [N, 4]");
52+
PADDLE_ENFORCE_EQ(prior_box_var_dims.size(), 1,
53+
"The rank of Input of PriorBoxVar must be 1");
54+
PADDLE_ENFORCE_EQ(prior_box_var_dims[0], 4,
55+
"The shape of PriorBoxVar is [4]");
56+
PADDLE_ENFORCE_EQ(target_box_dims.size(), 2,
57+
"The rank of Input of TargetBox must be 2");
58+
PADDLE_ENFORCE_EQ(box_score_dims.size(), 2,
59+
"The rank of Input of BoxScore must be 2");
60+
PADDLE_ENFORCE_EQ(prior_box_dims[0], target_box_dims[0],
61+
"The first dim of prior_box and target_box is roi nums "
62+
"and should be same!");
63+
PADDLE_ENFORCE_EQ(prior_box_dims[0], box_score_dims[0],
64+
"The first dim of prior_box and box_score is roi nums "
65+
"and should be same!");
66+
PADDLE_ENFORCE_EQ(target_box_dims[1], box_score_dims[1] * prior_box_dims[1],
67+
"The shape of target_box is [N, classnum * 4], The shape "
68+
"of box_score is [N, classnum], The shape of prior_box "
69+
"is [N, 4]");
70+
71+
ctx->SetOutputDim("DecodeBox", framework::make_ddim({target_box_dims[0],
72+
target_box_dims[1]}));
73+
ctx->ShareLoD("TargetBox", /*->*/ "DecodeBox");
74+
ctx->SetOutputDim(
75+
"OutputAssignBox",
76+
framework::make_ddim({prior_box_dims[0], prior_box_dims[1]}));
77+
ctx->ShareLoD("PriorBox", /*->*/ "OutputAssignBox");
78+
}
79+
};
80+
81+
class BoxDecoderAndAssignOpMaker : public framework::OpProtoAndCheckerMaker {
82+
public:
83+
void Make() override {
84+
AddInput(
85+
"PriorBox",
86+
"(Tensor, default Tensor<float>) "
87+
"Box list PriorBox is a 2-D Tensor with shape [N, 4] which holds N "
88+
"boxes and each box is represented as [xmin, ymin, xmax, ymax], "
89+
"[xmin, ymin] is the left top coordinate of the anchor box, "
90+
"if the input is image feature map, they are close to the origin "
91+
"of the coordinate system. [xmax, ymax] is the right bottom "
92+
"coordinate of the anchor box.");
93+
AddInput("PriorBoxVar",
94+
"(Tensor, default Tensor<float>, optional) "
95+
"PriorBoxVar is a 2-D Tensor with shape [N, 4] which holds N "
96+
"group of variance. PriorBoxVar will set all elements to 1 by "
97+
"default.")
98+
.AsDispensable();
99+
AddInput("TargetBox",
100+
"(LoDTensor or Tensor) "
101+
"This input can be a 2-D LoDTensor with shape "
102+
"[N, classnum*4]. It holds N targets for N boxes.");
103+
AddInput("BoxScore",
104+
"(LoDTensor or Tensor) "
105+
"This input can be a 2-D LoDTensor with shape "
106+
"[N, classnum], each box is represented as [classnum] which is "
107+
"the classification probabilities.");
108+
AddAttr<float>("box_clip",
109+
"(float, default 4.135, np.log(1000. / 16.)) "
110+
"clip box to prevent overflowing")
111+
.SetDefault(4.135f);
112+
AddOutput("DecodeBox",
113+
"(LoDTensor or Tensor) "
114+
"the output tensor of op with shape [N, classnum * 4] "
115+
"representing the result of N target boxes decoded with "
116+
"M Prior boxes and variances for each class.");
117+
AddOutput("OutputAssignBox",
118+
"(LoDTensor or Tensor) "
119+
"the output tensor of op with shape [N, 4] "
120+
"representing the result of N target boxes decoded with "
121+
"M Prior boxes and variances with the best non-background class "
122+
"by BoxScore.");
123+
AddComment(R"DOC(
124+
125+
Bounding Box Coder.
126+
127+
Decode the target bounding box with the prior_box information.
128+
129+
The Decoding schema is described below:
130+
131+
$$
132+
ox = (pw \\times pxv \\times tx + px) - \\frac{tw}{2}
133+
$$
134+
$$
135+
oy = (ph \\times pyv \\times ty + py) - \\frac{th}{2}
136+
$$
137+
$$
138+
ow = \\exp (pwv \\times tw) \\times pw + \\frac{tw}{2}
139+
$$
140+
$$
141+
oh = \\exp (phv \\times th) \\times ph + \\frac{th}{2}
142+
$$
143+
144+
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width
145+
and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the
146+
prior_box's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`,
147+
`phv` denote the variance of the prior_box and `ox`, `oy`, `ow`, `oh` denote the
148+
decoded coordinates, width and height in decode_box.
149+
150+
decode_box is obtained after box decode, then assigning schema is described below:
151+
152+
For each prior_box, use the best non-background class's decoded values to
153+
update the prior_box locations and get output_assign_box. So, the shape of
154+
output_assign_box is the same as PriorBox.
155+
)DOC");
156+
}
157+
};
158+
159+
} // namespace operators
160+
} // namespace paddle
161+
162+
namespace ops = paddle::operators;
163+
REGISTER_OPERATOR(box_decoder_and_assign, ops::BoxDecoderAndAssignOp,
164+
ops::BoxDecoderAndAssignOpMaker,
165+
paddle::framework::EmptyGradOpMaker);
166+
REGISTER_OP_CPU_KERNEL(
167+
box_decoder_and_assign,
168+
ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, float>,
169+
ops::BoxDecoderAndAssignKernel<paddle::platform::CPUDeviceContext, double>);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/memory/memcpy.h"
13+
#include "paddle/fluid/operators/detection/box_decoder_and_assign_op.h"
14+
#include "paddle/fluid/platform/cuda_primitives.h"
15+
16+
namespace paddle {
17+
namespace operators {
18+
19+
template <typename T>
20+
__global__ void DecodeBoxKernel(const T* prior_box_data,
21+
const T* prior_box_var_data,
22+
const T* target_box_data, const int roi_num,
23+
const int class_num, const T box_clip,
24+
T* output_box_data) {
25+
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
26+
if (idx < roi_num * class_num) {
27+
int i = idx / class_num;
28+
int j = idx % class_num;
29+
T prior_box_width = prior_box_data[i * 4 + 2] - prior_box_data[i * 4] + 1;
30+
T prior_box_height =
31+
prior_box_data[i * 4 + 3] - prior_box_data[i * 4 + 1] + 1;
32+
T prior_box_center_x = prior_box_data[i * 4] + prior_box_width / 2;
33+
T prior_box_center_y = prior_box_data[i * 4 + 1] + prior_box_height / 2;
34+
35+
int offset = i * class_num * 4 + j * 4;
36+
T dw = prior_box_var_data[2] * target_box_data[offset + 2];
37+
T dh = prior_box_var_data[3] * target_box_data[offset + 3];
38+
if (dw > box_clip) {
39+
dw = box_clip;
40+
}
41+
if (dh > box_clip) {
42+
dh = box_clip;
43+
}
44+
T target_box_center_x = 0, target_box_center_y = 0;
45+
T target_box_width = 0, target_box_height = 0;
46+
target_box_center_x =
47+
prior_box_var_data[0] * target_box_data[offset] * prior_box_width +
48+
prior_box_center_x;
49+
target_box_center_y =
50+
prior_box_var_data[1] * target_box_data[offset + 1] * prior_box_height +
51+
prior_box_center_y;
52+
target_box_width = expf(dw) * prior_box_width;
53+
target_box_height = expf(dh) * prior_box_height;
54+
55+
output_box_data[offset] = target_box_center_x - target_box_width / 2;
56+
output_box_data[offset + 1] = target_box_center_y - target_box_height / 2;
57+
output_box_data[offset + 2] =
58+
target_box_center_x + target_box_width / 2 - 1;
59+
output_box_data[offset + 3] =
60+
target_box_center_y + target_box_height / 2 - 1;
61+
}
62+
}
63+
64+
template <typename T>
65+
__global__ void AssignBoxKernel(const T* prior_box_data,
66+
const T* box_score_data, T* output_box_data,
67+
const int roi_num, const int class_num,
68+
T* output_assign_box_data) {
69+
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
70+
if (idx < roi_num) {
71+
int i = idx;
72+
T max_score = -1;
73+
int max_j = -1;
74+
for (int j = 0; j < class_num; ++j) {
75+
T score = box_score_data[i * class_num + j];
76+
if (score > max_score && j > 0) {
77+
max_score = score;
78+
max_j = j;
79+
}
80+
}
81+
if (max_j > 0) {
82+
for (int pno = 0; pno < 4; pno++) {
83+
output_assign_box_data[i * 4 + pno] =
84+
output_box_data[i * class_num * 4 + max_j * 4 + pno];
85+
}
86+
} else {
87+
for (int pno = 0; pno < 4; pno++) {
88+
output_assign_box_data[i * 4 + pno] = prior_box_data[i * 4 + pno];
89+
}
90+
}
91+
}
92+
}
93+
94+
template <typename DeviceContext, typename T>
95+
class BoxDecoderAndAssignCUDAKernel : public framework::OpKernel<T> {
96+
public:
97+
void Compute(const framework::ExecutionContext& context) const override {
98+
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
99+
"This kernel only runs on GPU device.");
100+
auto* prior_box = context.Input<framework::LoDTensor>("PriorBox");
101+
auto* prior_box_var = context.Input<framework::Tensor>("PriorBoxVar");
102+
auto* target_box = context.Input<framework::LoDTensor>("TargetBox");
103+
auto* box_score = context.Input<framework::LoDTensor>("BoxScore");
104+
auto* output_box = context.Output<framework::Tensor>("DecodeBox");
105+
auto* output_assign_box =
106+
context.Output<framework::Tensor>("OutputAssignBox");
107+
108+
auto roi_num = target_box->dims()[0];
109+
auto class_num = box_score->dims()[1];
110+
auto* target_box_data = target_box->data<T>();
111+
auto* prior_box_data = prior_box->data<T>();
112+
auto* prior_box_var_data = prior_box_var->data<T>();
113+
auto* box_score_data = box_score->data<T>();
114+
output_box->mutable_data<T>({roi_num, class_num * 4}, context.GetPlace());
115+
output_assign_box->mutable_data<T>({roi_num, 4}, context.GetPlace());
116+
T* output_box_data = output_box->data<T>();
117+
T* output_assign_box_data = output_assign_box->data<T>();
118+
119+
int block = 512;
120+
int grid = (roi_num * class_num + block - 1) / block;
121+
auto& device_ctx = context.cuda_device_context();
122+
123+
const T box_clip = context.Attr<T>("box_clip");
124+
125+
DecodeBoxKernel<T><<<grid, block, 0, device_ctx.stream()>>>(
126+
prior_box_data, prior_box_var_data, target_box_data, roi_num, class_num,
127+
box_clip, output_box_data);
128+
129+
context.device_context().Wait();
130+
int assign_grid = (roi_num + block - 1) / block;
131+
AssignBoxKernel<T><<<assign_grid, block, 0, device_ctx.stream()>>>(
132+
prior_box_data, box_score_data, output_box_data, roi_num, class_num,
133+
output_assign_box_data);
134+
context.device_context().Wait();
135+
}
136+
};
137+
138+
} // namespace operators
139+
} // namespace paddle
140+
141+
namespace ops = paddle::operators;
142+
REGISTER_OP_CUDA_KERNEL(
143+
box_decoder_and_assign,
144+
ops::BoxDecoderAndAssignCUDAKernel<paddle::platform::CUDADeviceContext,
145+
float>,
146+
ops::BoxDecoderAndAssignCUDAKernel<paddle::platform::CUDADeviceContext,
147+
double>);

0 commit comments

Comments
 (0)