Commit 1c6d064

add collect fpn proposals op,test=develop (#16074)
* add collect fpn proposals op,test=develop
1 parent 60be66e commit 1c6d064

9 files changed, +667 −4 lines changed

paddle/fluid/API.spec (+1)

@@ -360,6 +360,7 @@ paddle.fluid.layers.box_clip (ArgSpec(args=['input', 'im_info', 'name'], varargs
 paddle.fluid.layers.multiclass_nms (ArgSpec(args=['bboxes', 'scores', 'score_threshold', 'nms_top_k', 'keep_top_k', 'nms_threshold', 'normalized', 'nms_eta', 'background_label', 'name'], varargs=None, keywords=None, defaults=(0.3, True, 1.0, 0, None)), ('document', 'ca7d1107b6c5d2d6d8221039a220fde0'))
 paddle.fluid.layers.distribute_fpn_proposals (ArgSpec(args=['fpn_rois', 'min_level', 'max_level', 'refer_level', 'refer_scale', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '7bb011ec26bace2bc23235aa4a17647d'))
 paddle.fluid.layers.box_decoder_and_assign (ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'box_score', 'box_clip', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'dfc953994fd8fef35c49dd9c6eea37a5'))
+paddle.fluid.layers.collect_fpn_proposals (ArgSpec(args=['multi_rois', 'multi_scores', 'min_level', 'max_level', 'post_nms_top_n', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '82ffd896ecc3c005ae1cad40854dcace'))
 paddle.fluid.layers.accuracy (ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)), ('document', '9808534c12c5e739a10f73ebb0b4eafd'))
 paddle.fluid.layers.auc (ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)), ('document', 'e0e95334fce92d16c2d9db6e7caffc47'))
 paddle.fluid.layers.exponential_decay (ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)), ('document', '98a5050bee8522fcea81aa795adaba51'))
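
The new API.spec entry pins the Python-side signature of the layer this commit adds: collect_fpn_proposals(multi_rois, multi_scores, min_level, max_level, post_nms_top_n, name=None). Below is a minimal usage sketch against that signature; the level range 2-5, the variable names, and the value 1000 for post_nms_top_n are illustrative assumptions, not part of this diff.

import paddle.fluid as fluid

# One (rois, scores) pair per FPN level. Shapes mirror the op's inputs:
# MultiLevelRois is (N, 4) and MultiLevelScores is (N, 1), both LoD tensors.
multi_rois = []
multi_scores = []
for level in range(2, 6):  # illustrative level range
    multi_rois.append(fluid.layers.data(
        name='roi_level_%d' % level, shape=[4],
        dtype='float32', lod_level=1))
    multi_scores.append(fluid.layers.data(
        name='score_level_%d' % level, shape=[1],
        dtype='float32', lod_level=1))

# Keep the 1000 highest-scoring proposals across all levels and images.
fpn_rois = fluid.layers.collect_fpn_proposals(
    multi_rois=multi_rois,
    multi_scores=multi_scores,
    min_level=2,
    max_level=5,
    post_nms_top_n=1000)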

paddle/fluid/operators/detection/CMakeLists.txt (+2)

@@ -39,9 +39,11 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo
 if(WITH_GPU)
   detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS memory cub)
   detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc distribute_fpn_proposals_op.cu DEPS memory cub)
+  detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc collect_fpn_proposals_op.cu DEPS memory cub)
 else()
   detection_library(generate_proposals_op SRCS generate_proposals_op.cc)
   detection_library(distribute_fpn_proposals_op SRCS distribute_fpn_proposals_op.cc)
+  detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
 endif()

 detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)

paddle/fluid/operators/detection/bbox_util.h (+4 −4)

@@ -22,10 +22,10 @@ namespace paddle {
 namespace operators {

 struct RangeInitFunctor {
-  int start_;
-  int delta_;
-  int* out_;
-  HOSTDEVICE void operator()(size_t i) { out_[i] = start_ + i * delta_; }
+  int start;
+  int delta;
+  int* out;
+  HOSTDEVICE void operator()(size_t i) { out[i] = start + i * delta; }
 };

 template <typename T>

paddle/fluid/operators/detection/collect_fpn_proposals_op.cc (new file, +108)

@@ -0,0 +1,108 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*/

#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class CollectFpnProposalsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInputs("MultiLevelRois"),
                   "Inputs(MultiLevelRois) shouldn't be null");
    PADDLE_ENFORCE(context->HasInputs("MultiLevelScores"),
                   "Inputs(MultiLevelScores) shouldn't be null");
    PADDLE_ENFORCE(context->HasOutput("FpnRois"),
                   "Outputs(MultiFpnRois) of DistributeOp should not be null");
    auto roi_dims = context->GetInputsDim("MultiLevelRois");
    auto score_dims = context->GetInputsDim("MultiLevelScores");
    auto post_nms_topN = context->Attrs().Get<int>("post_nms_topN");
    std::vector<int64_t> out_dims;
    for (auto &roi_dim : roi_dims) {
      PADDLE_ENFORCE_EQ(roi_dim[1], 4,
                        "Second dimension of Input(MultiLevelRois) must be 4");
    }
    for (auto &score_dim : score_dims) {
      PADDLE_ENFORCE_EQ(
          score_dim[1], 1,
          "Second dimension of Input(MultiLevelScores) must be 1");
    }
    context->SetOutputDim("FpnRois", {post_nms_topN, 4});
    if (!context->IsRuntime()) {  // Runtime LoD infershape will be computed
                                  // in Kernel.
      context->ShareLoD("MultiLevelRois", "FpnRois");
    }
    if (context->IsRuntime()) {
      std::vector<framework::InferShapeVarPtr> roi_inputs =
          context->GetInputVarPtrs("MultiLevelRois");
      std::vector<framework::InferShapeVarPtr> score_inputs =
          context->GetInputVarPtrs("MultiLevelScores");
      for (size_t i = 0; i < roi_inputs.size(); ++i) {
        framework::Variable *roi_var =
            boost::get<framework::Variable *>(roi_inputs[i]);
        framework::Variable *score_var =
            boost::get<framework::Variable *>(score_inputs[i]);
        auto &roi_lod = roi_var->Get<LoDTensor>().lod();
        auto &score_lod = score_var->Get<LoDTensor>().lod();
        PADDLE_ENFORCE_EQ(roi_lod, score_lod,
                          "Inputs(MultiLevelRois) and Inputs(MultiLevelScores) "
                          "should have same lod.");
      }
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type =
        framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]);
    return framework::OpKernelType(data_type, ctx.GetPlace());
  }
};

class CollectFpnProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("MultiLevelRois",
             "(LoDTensor) Multiple roi LoDTensors from each level in shape "
             "(N, 4), N is the number of RoIs")
        .AsDuplicable();
    AddInput("MultiLevelScores",
             "(LoDTensor) Multiple score LoDTensors from each level in shape"
             " (N, 1), N is the number of RoIs.")
        .AsDuplicable();
    AddOutput("FpnRois", "(LoDTensor) All selected RoIs with highest scores");
    AddAttr<int>("post_nms_topN",
                 "Select post_nms_topN RoIs from"
                 " all images and all fpn layers");
    AddComment(R"DOC(
This operator concats all proposals from different images
and different FPN levels. Then sort all of those proposals
by objectness confidence. Select the post_nms_topN RoIs in
total. Finally, re-sort the RoIs in the order of batch index.
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(collect_fpn_proposals, ops::CollectFpnProposalsOp,
                  ops::CollectFpnProposalsOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(collect_fpn_proposals,
                       ops::CollectFpnProposalsOpKernel<float>,
                       ops::CollectFpnProposalsOpKernel<double>);
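
The DOC string above summarizes the op: concatenate proposals from every image and FPN level, sort them by objectness score, keep the top post_nms_topN, then re-order the survivors by batch index. The CollectFpnProposalsOpKernel registered here is declared in collect_fpn_proposals_op.h, which is not shown in this excerpt; the following NumPy sketch is only a reference for those semantics, with plain lists of arrays standing in for the LoD tensor inputs.

import numpy as np

def collect_fpn_proposals_ref(multi_rois, multi_scores, multi_batch_ids,
                              post_nms_top_n):
    # multi_rois:      list of (N_i, 4) arrays, one per FPN level
    # multi_scores:    list of (N_i,)   arrays, one per FPN level
    # multi_batch_ids: list of (N_i,)   int arrays, image index of each RoI
    rois = np.concatenate(multi_rois, axis=0)
    scores = np.concatenate(multi_scores, axis=0)
    batch_ids = np.concatenate(multi_batch_ids, axis=0)

    # sort all proposals by objectness score, descending, and keep the top-N
    keep = np.argsort(-scores)[:min(post_nms_top_n, scores.size)]
    rois, batch_ids = rois[keep], batch_ids[keep]

    # re-sort the kept RoIs by batch index; a stable sort preserves the
    # descending score order within each image
    order = np.argsort(batch_ids, kind='stable')
    rois, batch_ids = rois[order], batch_ids[order]

    # length-based lod: number of kept RoIs per image
    lod = np.bincount(batch_ids).tolist()
    return rois, lod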

paddle/fluid/operators/detection/collect_fpn_proposals_op.cu (new file, +211)

@@ -0,0 +1,211 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <paddle/fluid/memory/allocation/allocator.h>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

static constexpr int kNumCUDAThreads = 64;
static constexpr int kNumMaxinumNumBlocks = 4096;

const int kBBoxSize = 4;

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaxinumNumBlocks);
}

static __global__ void GetLengthLoD(const int nthreads, const int* batch_ids,
                                    int* length_lod) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (nthreads);
       i += blockDim.x * gridDim.x) {
    platform::CudaAtomicAdd(length_lod + batch_ids[i], 1);
  }
}

template <typename DeviceContext, typename T>
class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto roi_ins = ctx.MultiInput<LoDTensor>("MultiLevelRois");
    const auto score_ins = ctx.MultiInput<LoDTensor>("MultiLevelScores");
    auto fpn_rois = ctx.Output<LoDTensor>("FpnRois");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    const int post_nms_topN = ctx.Attr<int>("post_nms_topN");

    // concat inputs along axis = 0
    int roi_offset = 0;
    int score_offset = 0;
    int total_roi_num = 0;
    for (size_t i = 0; i < roi_ins.size(); ++i) {
      total_roi_num += roi_ins[i]->dims()[0];
    }

    int real_post_num = min(post_nms_topN, total_roi_num);
    fpn_rois->mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
    Tensor concat_rois;
    Tensor concat_scores;
    T* concat_rois_data = concat_rois.mutable_data<T>(
        {total_roi_num, kBBoxSize}, dev_ctx.GetPlace());
    T* concat_scores_data =
        concat_scores.mutable_data<T>({total_roi_num, 1}, dev_ctx.GetPlace());
    Tensor roi_batch_id_list;
    roi_batch_id_list.Resize({total_roi_num});
    int* roi_batch_id_data =
        roi_batch_id_list.mutable_data<int>(platform::CPUPlace());
    int index = 0;
    int lod_size;
    auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace());

    for (size_t i = 0; i < roi_ins.size(); ++i) {
      auto roi_in = roi_ins[i];
      auto score_in = score_ins[i];
      auto roi_lod = roi_in->lod().back();
      lod_size = roi_lod.size() - 1;
      for (size_t n = 0; n < lod_size; ++n) {
        for (size_t j = roi_lod[n]; j < roi_lod[n + 1]; ++j) {
          roi_batch_id_data[index++] = n;
        }
      }

      memory::Copy(place, concat_rois_data + roi_offset, place,
                   roi_in->data<T>(), roi_in->numel() * sizeof(T),
                   dev_ctx.stream());
      memory::Copy(place, concat_scores_data + score_offset, place,
                   score_in->data<T>(), score_in->numel() * sizeof(T),
                   dev_ctx.stream());
      roi_offset += roi_in->numel();
      score_offset += score_in->numel();
    }

    // copy batch id list to GPU
    Tensor roi_batch_id_list_gpu;
    framework::TensorCopy(roi_batch_id_list, dev_ctx.GetPlace(),
                          &roi_batch_id_list_gpu);

    Tensor index_in_t;
    int* idx_in =
        index_in_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());
    platform::ForRange<platform::CUDADeviceContext> for_range_total(
        dev_ctx, total_roi_num);
    for_range_total(RangeInitFunctor{0, 1, idx_in});

    Tensor keys_out_t;
    T* keys_out =
        keys_out_t.mutable_data<T>({total_roi_num}, dev_ctx.GetPlace());
    Tensor index_out_t;
    int* idx_out =
        index_out_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace());

    // Determine temporary device storage requirements
    size_t temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairsDescending<T, int>(
        nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
        idx_out, total_roi_num);
    // Allocate temporary storage
    auto d_temp_storage = memory::Alloc(place, temp_storage_bytes,
                                        memory::Allocator::kScratchpad);

    // Run sorting operation
    // sort score to get corresponding index
    cub::DeviceRadixSort::SortPairsDescending<T, int>(
        d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
        keys_out, idx_in, idx_out, total_roi_num);
    index_out_t.Resize({real_post_num});
    Tensor sorted_rois;
    sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
    Tensor sorted_batch_id;
    sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
    GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois);
    GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t,
                   &sorted_batch_id);

    Tensor batch_index_t;
    int* batch_idx_in =
        batch_index_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
    platform::ForRange<platform::CUDADeviceContext> for_range_post(
        dev_ctx, real_post_num);
    for_range_post(RangeInitFunctor{0, 1, batch_idx_in});

    Tensor out_id_t;
    int* out_id_data =
        out_id_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace());
    // Determine temporary device storage requirements
    temp_storage_bytes = 0;
    cub::DeviceRadixSort::SortPairs<int, int>(
        nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
        batch_idx_in, index_out_t.data<int>(), real_post_num);
    // Allocate temporary storage
    d_temp_storage = memory::Alloc(place, temp_storage_bytes,
                                   memory::Allocator::kScratchpad);

    // Run sorting operation
    // sort batch_id to get corresponding index
    cub::DeviceRadixSort::SortPairs<int, int>(
        d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
        out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num);

    GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);

    Tensor length_lod;
    int* length_lod_data =
        length_lod.mutable_data<int>({lod_size}, dev_ctx.GetPlace());
    math::SetConstant<platform::CUDADeviceContext, int> set_zero;
    set_zero(dev_ctx, &length_lod, static_cast<int>(0));

    int blocks = NumBlocks(real_post_num);
    int threads = kNumCUDAThreads;

    // get length-based lod by batch ids
    GetLengthLoD<<<blocks, threads>>>(real_post_num, out_id_data,
                                      length_lod_data);
    std::vector<int> length_lod_cpu(lod_size);
    memory::Copy(platform::CPUPlace(), length_lod_cpu.data(), place,
                 length_lod_data, sizeof(int) * lod_size, dev_ctx.stream());
    dev_ctx.Wait();

    std::vector<size_t> offset(1, 0);
    for (int i = 0; i < lod_size; ++i) {
      offset.emplace_back(offset.back() + length_lod_cpu[i]);
    }

    framework::LoD lod;
    lod.emplace_back(offset);
    fpn_rois->set_lod(lod);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    collect_fpn_proposals,
    ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext,
                                        double>);
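
Two details of the CUDA kernel are easy to miss. First, each cub::DeviceRadixSort call appears twice because CUB's two-pass API is invoked once with a null temporary-storage pointer to report temp_storage_bytes, and again to actually sort. Second, GetLengthLoD only produces per-image counts; the final loop converts them into the offset-based LoD that FpnRois carries. A tiny Python sketch of that conversion, with made-up lengths:

# length-based lod -> offset-based lod, mirroring the final loop of the kernel
length_lod = [3, 0, 2]      # illustrative: RoIs kept for each of 3 images
offset = [0]
for n in length_lod:
    offset.append(offset[-1] + n)
# offset == [0, 3, 3, 5]; this is what fpn_rois->set_lod receives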
