|
| 1 | +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. |
| 2 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +you may not use this file except in compliance with the License. |
| 4 | +You may obtain a copy of the License at |
| 5 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 6 | + Unless required by applicable law or agreed to in writing, software |
| 7 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 8 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 9 | +See the License for the specific language governing permissions and |
| 10 | +limitations under the License. */ |
| 11 | + |
| 12 | +#include <paddle/fluid/memory/allocation/allocator.h> |
| 13 | +#include "cub/cub.cuh" |
| 14 | +#include "paddle/fluid/framework/mixed_vector.h" |
| 15 | +#include "paddle/fluid/framework/op_registry.h" |
| 16 | +#include "paddle/fluid/memory/memcpy.h" |
| 17 | +#include "paddle/fluid/operators/detection/bbox_util.h" |
| 18 | +#include "paddle/fluid/operators/detection/collect_fpn_proposals_op.h" |
| 19 | +#include "paddle/fluid/operators/gather.cu.h" |
| 20 | +#include "paddle/fluid/operators/math/concat_and_split.h" |
| 21 | +#include "paddle/fluid/operators/strided_memcpy.h" |
| 22 | +#include "paddle/fluid/platform/cuda_primitives.h" |
| 23 | +#include "paddle/fluid/platform/for_range.h" |
| 24 | + |
| 25 | +namespace paddle { |
| 26 | +namespace operators { |
| 27 | + |
| 28 | +using Tensor = framework::Tensor; |
| 29 | +using LoDTensor = framework::LoDTensor; |
| 30 | + |
| 31 | +static constexpr int kNumCUDAThreads = 64; |
| 32 | +static constexpr int kNumMaxinumNumBlocks = 4096; |
| 33 | + |
| 34 | +const int kBBoxSize = 4; |
| 35 | + |
| 36 | +static inline int NumBlocks(const int N) { |
| 37 | + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, |
| 38 | + kNumMaxinumNumBlocks); |
| 39 | +} |
| 40 | + |
| 41 | +static __global__ void GetLengthLoD(const int nthreads, const int* batch_ids, |
| 42 | + int* length_lod) { |
| 43 | + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (nthreads); |
| 44 | + i += blockDim.x * gridDim.x) { |
| 45 | + platform::CudaAtomicAdd(length_lod + batch_ids[i], 1); |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +template <typename DeviceContext, typename T> |
| 50 | +class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> { |
| 51 | + public: |
| 52 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 53 | + const auto roi_ins = ctx.MultiInput<LoDTensor>("MultiLevelRois"); |
| 54 | + const auto score_ins = ctx.MultiInput<LoDTensor>("MultiLevelScores"); |
| 55 | + auto fpn_rois = ctx.Output<LoDTensor>("FpnRois"); |
| 56 | + auto& dev_ctx = ctx.template device_context<DeviceContext>(); |
| 57 | + |
| 58 | + const int post_nms_topN = ctx.Attr<int>("post_nms_topN"); |
| 59 | + |
| 60 | + // concat inputs along axis = 0 |
| 61 | + int roi_offset = 0; |
| 62 | + int score_offset = 0; |
| 63 | + int total_roi_num = 0; |
| 64 | + for (size_t i = 0; i < roi_ins.size(); ++i) { |
| 65 | + total_roi_num += roi_ins[i]->dims()[0]; |
| 66 | + } |
| 67 | + |
| 68 | + int real_post_num = min(post_nms_topN, total_roi_num); |
| 69 | + fpn_rois->mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace()); |
| 70 | + Tensor concat_rois; |
| 71 | + Tensor concat_scores; |
| 72 | + T* concat_rois_data = concat_rois.mutable_data<T>( |
| 73 | + {total_roi_num, kBBoxSize}, dev_ctx.GetPlace()); |
| 74 | + T* concat_scores_data = |
| 75 | + concat_scores.mutable_data<T>({total_roi_num, 1}, dev_ctx.GetPlace()); |
| 76 | + Tensor roi_batch_id_list; |
| 77 | + roi_batch_id_list.Resize({total_roi_num}); |
| 78 | + int* roi_batch_id_data = |
| 79 | + roi_batch_id_list.mutable_data<int>(platform::CPUPlace()); |
| 80 | + int index = 0; |
| 81 | + int lod_size; |
| 82 | + auto place = boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()); |
| 83 | + |
| 84 | + for (size_t i = 0; i < roi_ins.size(); ++i) { |
| 85 | + auto roi_in = roi_ins[i]; |
| 86 | + auto score_in = score_ins[i]; |
| 87 | + auto roi_lod = roi_in->lod().back(); |
| 88 | + lod_size = roi_lod.size() - 1; |
| 89 | + for (size_t n = 0; n < lod_size; ++n) { |
| 90 | + for (size_t j = roi_lod[n]; j < roi_lod[n + 1]; ++j) { |
| 91 | + roi_batch_id_data[index++] = n; |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + memory::Copy(place, concat_rois_data + roi_offset, place, |
| 96 | + roi_in->data<T>(), roi_in->numel() * sizeof(T), |
| 97 | + dev_ctx.stream()); |
| 98 | + memory::Copy(place, concat_scores_data + score_offset, place, |
| 99 | + score_in->data<T>(), score_in->numel() * sizeof(T), |
| 100 | + dev_ctx.stream()); |
| 101 | + roi_offset += roi_in->numel(); |
| 102 | + score_offset += score_in->numel(); |
| 103 | + } |
| 104 | + |
| 105 | + // copy batch id list to GPU |
| 106 | + Tensor roi_batch_id_list_gpu; |
| 107 | + framework::TensorCopy(roi_batch_id_list, dev_ctx.GetPlace(), |
| 108 | + &roi_batch_id_list_gpu); |
| 109 | + |
| 110 | + Tensor index_in_t; |
| 111 | + int* idx_in = |
| 112 | + index_in_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace()); |
| 113 | + platform::ForRange<platform::CUDADeviceContext> for_range_total( |
| 114 | + dev_ctx, total_roi_num); |
| 115 | + for_range_total(RangeInitFunctor{0, 1, idx_in}); |
| 116 | + |
| 117 | + Tensor keys_out_t; |
| 118 | + T* keys_out = |
| 119 | + keys_out_t.mutable_data<T>({total_roi_num}, dev_ctx.GetPlace()); |
| 120 | + Tensor index_out_t; |
| 121 | + int* idx_out = |
| 122 | + index_out_t.mutable_data<int>({total_roi_num}, dev_ctx.GetPlace()); |
| 123 | + |
| 124 | + // Determine temporary device storage requirements |
| 125 | + size_t temp_storage_bytes = 0; |
| 126 | + cub::DeviceRadixSort::SortPairsDescending<T, int>( |
| 127 | + nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in, |
| 128 | + idx_out, total_roi_num); |
| 129 | + // Allocate temporary storage |
| 130 | + auto d_temp_storage = memory::Alloc(place, temp_storage_bytes, |
| 131 | + memory::Allocator::kScratchpad); |
| 132 | + |
| 133 | + // Run sorting operation |
| 134 | + // sort score to get corresponding index |
| 135 | + cub::DeviceRadixSort::SortPairsDescending<T, int>( |
| 136 | + d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(), |
| 137 | + keys_out, idx_in, idx_out, total_roi_num); |
| 138 | + index_out_t.Resize({real_post_num}); |
| 139 | + Tensor sorted_rois; |
| 140 | + sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace()); |
| 141 | + Tensor sorted_batch_id; |
| 142 | + sorted_batch_id.mutable_data<int>({real_post_num}, dev_ctx.GetPlace()); |
| 143 | + GPUGather<T>(dev_ctx, concat_rois, index_out_t, &sorted_rois); |
| 144 | + GPUGather<int>(dev_ctx, roi_batch_id_list_gpu, index_out_t, |
| 145 | + &sorted_batch_id); |
| 146 | + |
| 147 | + Tensor batch_index_t; |
| 148 | + int* batch_idx_in = |
| 149 | + batch_index_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace()); |
| 150 | + platform::ForRange<platform::CUDADeviceContext> for_range_post( |
| 151 | + dev_ctx, real_post_num); |
| 152 | + for_range_post(RangeInitFunctor{0, 1, batch_idx_in}); |
| 153 | + |
| 154 | + Tensor out_id_t; |
| 155 | + int* out_id_data = |
| 156 | + out_id_t.mutable_data<int>({real_post_num}, dev_ctx.GetPlace()); |
| 157 | + // Determine temporary device storage requirements |
| 158 | + temp_storage_bytes = 0; |
| 159 | + cub::DeviceRadixSort::SortPairs<int, int>( |
| 160 | + nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data, |
| 161 | + batch_idx_in, index_out_t.data<int>(), real_post_num); |
| 162 | + // Allocate temporary storage |
| 163 | + d_temp_storage = memory::Alloc(place, temp_storage_bytes, |
| 164 | + memory::Allocator::kScratchpad); |
| 165 | + |
| 166 | + // Run sorting operation |
| 167 | + // sort batch_id to get corresponding index |
| 168 | + cub::DeviceRadixSort::SortPairs<int, int>( |
| 169 | + d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(), |
| 170 | + out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num); |
| 171 | + |
| 172 | + GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois); |
| 173 | + |
| 174 | + Tensor length_lod; |
| 175 | + int* length_lod_data = |
| 176 | + length_lod.mutable_data<int>({lod_size}, dev_ctx.GetPlace()); |
| 177 | + math::SetConstant<platform::CUDADeviceContext, int> set_zero; |
| 178 | + set_zero(dev_ctx, &length_lod, static_cast<int>(0)); |
| 179 | + |
| 180 | + int blocks = NumBlocks(real_post_num); |
| 181 | + int threads = kNumCUDAThreads; |
| 182 | + |
| 183 | + // get length-based lod by batch ids |
| 184 | + GetLengthLoD<<<blocks, threads>>>(real_post_num, out_id_data, |
| 185 | + length_lod_data); |
| 186 | + std::vector<int> length_lod_cpu(lod_size); |
| 187 | + memory::Copy(platform::CPUPlace(), length_lod_cpu.data(), place, |
| 188 | + length_lod_data, sizeof(int) * lod_size, dev_ctx.stream()); |
| 189 | + dev_ctx.Wait(); |
| 190 | + |
| 191 | + std::vector<size_t> offset(1, 0); |
| 192 | + for (int i = 0; i < lod_size; ++i) { |
| 193 | + offset.emplace_back(offset.back() + length_lod_cpu[i]); |
| 194 | + } |
| 195 | + |
| 196 | + framework::LoD lod; |
| 197 | + lod.emplace_back(offset); |
| 198 | + fpn_rois->set_lod(lod); |
| 199 | + } |
| 200 | +}; |
| 201 | + |
| 202 | +} // namespace operators |
| 203 | +} // namespace paddle |
| 204 | + |
| 205 | +namespace ops = paddle::operators; |
| 206 | +REGISTER_OP_CUDA_KERNEL( |
| 207 | + collect_fpn_proposals, |
| 208 | + ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext, |
| 209 | + float>, |
| 210 | + ops::GPUCollectFpnProposalsOpKernel<paddle::platform::CUDADeviceContext, |
| 211 | + double>); |
0 commit comments