|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/fluid/operators/prior_box_op.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | + |
| 20 | +template <typename T> |
| 21 | +__device__ inline T clip(T in) { |
| 22 | + return min(max(in, 0.), 1.); |
| 23 | +} |
| 24 | + |
| 25 | +template <typename T> |
| 26 | +__global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height, |
| 27 | + const int width, const int im_height, |
| 28 | + const int im_width, const int as_num, |
| 29 | + const T offset, const T step_width, |
| 30 | + const T step_height, const T* min_sizes, |
| 31 | + const T* max_sizes, const int min_num, |
| 32 | + bool is_clip) { |
| 33 | + int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num; |
| 34 | + int box_num = height * width * num_priors; |
| 35 | + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num; |
| 36 | + i += blockDim.x * gridDim.x) { |
| 37 | + int h = i / (num_priors * width); |
| 38 | + int w = (i / num_priors) % width; |
| 39 | + int p = i % num_priors; |
| 40 | + int m = max_sizes ? p / (as_num + 1) : p / as_num; |
| 41 | + T cx = (w + offset) * step_width; |
| 42 | + T cy = (h + offset) * step_height; |
| 43 | + T bw, bh; |
| 44 | + T min_size = min_sizes[m]; |
| 45 | + if (max_sizes) { |
| 46 | + int s = p % (as_num + 1); |
| 47 | + if (s < as_num) { |
| 48 | + T ar = aspect_ratios[s]; |
| 49 | + bw = min_size * sqrt(ar) / 2.; |
| 50 | + bh = min_size / sqrt(ar) / 2.; |
| 51 | + } else { |
| 52 | + T max_size = max_sizes[m]; |
| 53 | + bw = sqrt(min_size * max_size) / 2.; |
| 54 | + bh = bw; |
| 55 | + } |
| 56 | + } else { |
| 57 | + int s = p % as_num; |
| 58 | + T ar = aspect_ratios[s]; |
| 59 | + bw = min_size * sqrt(ar) / 2.; |
| 60 | + bh = min_size / sqrt(ar) / 2.; |
| 61 | + } |
| 62 | + T xmin = (cx - bw) / im_width; |
| 63 | + T ymin = (cy - bh) / im_height; |
| 64 | + T xmax = (cx + bw) / im_width; |
| 65 | + T ymax = (cy + bh) / im_height; |
| 66 | + out[i * 4] = is_clip ? clip<T>(xmin) : xmin; |
| 67 | + out[i * 4 + 1] = is_clip ? clip<T>(ymin) : ymin; |
| 68 | + out[i * 4 + 2] = is_clip ? clip<T>(xmax) : xmax; |
| 69 | + out[i * 4 + 3] = is_clip ? clip<T>(ymax) : ymax; |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +template <typename T> |
| 74 | +__global__ void SetVariance(T* out, const T* var, const int vnum, |
| 75 | + const int num) { |
| 76 | + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; |
| 77 | + i += blockDim.x * gridDim.x) { |
| 78 | + out[i] = var[i % vnum]; |
| 79 | + } |
| 80 | +} |
| 81 | + |
| 82 | +template <typename T> |
| 83 | +class PriorBoxOpCUDAKernel : public framework::OpKernel<T> { |
| 84 | + public: |
| 85 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 86 | + auto* input = ctx.Input<paddle::framework::Tensor>("Input"); |
| 87 | + auto* image = ctx.Input<paddle::framework::Tensor>("Image"); |
| 88 | + auto* boxes = ctx.Output<paddle::framework::Tensor>("Boxes"); |
| 89 | + auto* vars = ctx.Output<paddle::framework::Tensor>("Variances"); |
| 90 | + |
| 91 | + auto min_sizes = ctx.Attr<std::vector<float>>("min_sizes"); |
| 92 | + auto max_sizes = ctx.Attr<std::vector<float>>("max_sizes"); |
| 93 | + auto input_aspect_ratio = ctx.Attr<std::vector<float>>("aspect_ratios"); |
| 94 | + auto variances = ctx.Attr<std::vector<float>>("variances"); |
| 95 | + auto flip = ctx.Attr<bool>("flip"); |
| 96 | + auto clip = ctx.Attr<bool>("clip"); |
| 97 | + |
| 98 | + std::vector<float> aspect_ratios; |
| 99 | + ExpandAspectRatios(input_aspect_ratio, flip, aspect_ratios); |
| 100 | + |
| 101 | + T step_w = static_cast<T>(ctx.Attr<float>("step_w")); |
| 102 | + T step_h = static_cast<T>(ctx.Attr<float>("step_h")); |
| 103 | + T offset = static_cast<T>(ctx.Attr<float>("offset")); |
| 104 | + |
| 105 | + auto im_width = image->dims()[3]; |
| 106 | + auto im_height = image->dims()[2]; |
| 107 | + |
| 108 | + auto width = input->dims()[3]; |
| 109 | + auto height = input->dims()[2]; |
| 110 | + |
| 111 | + T step_width, step_height; |
| 112 | + if (step_w == 0 || step_h == 0) { |
| 113 | + step_width = static_cast<T>(im_width) / width; |
| 114 | + step_height = static_cast<T>(im_height) / height; |
| 115 | + } else { |
| 116 | + step_width = step_w; |
| 117 | + step_height = step_h; |
| 118 | + } |
| 119 | + |
| 120 | + int num_priors = aspect_ratios.size() * min_sizes.size(); |
| 121 | + if (max_sizes.size() > 0) { |
| 122 | + num_priors += max_sizes.size(); |
| 123 | + } |
| 124 | + int min_num = static_cast<int>(min_sizes.size()); |
| 125 | + int box_num = width * height * num_priors; |
| 126 | + |
| 127 | + int block = 512; |
| 128 | + int grid = (box_num + block - 1) / block; |
| 129 | + |
| 130 | + auto stream = |
| 131 | + ctx.template device_context<platform::CUDADeviceContext>().stream(); |
| 132 | + |
| 133 | + boxes->mutable_data<T>(ctx.GetPlace()); |
| 134 | + vars->mutable_data<T>(ctx.GetPlace()); |
| 135 | + |
| 136 | + framework::Tensor r; |
| 137 | + framework::TensorFromVector(aspect_ratios, ctx.device_context(), &r); |
| 138 | + |
| 139 | + framework::Tensor min; |
| 140 | + framework::TensorFromVector(min_sizes, ctx.device_context(), &min); |
| 141 | + |
| 142 | + T* max_data = nullptr; |
| 143 | + framework::Tensor max; |
| 144 | + if (max_sizes.size() > 0) { |
| 145 | + framework::TensorFromVector(max_sizes, ctx.device_context(), &max); |
| 146 | + max_data = max.data<T>(); |
| 147 | + } |
| 148 | + |
| 149 | + GenPriorBox<T><<<grid, block, 0, stream>>>( |
| 150 | + boxes->data<T>(), r.data<T>(), height, width, im_height, im_width, |
| 151 | + aspect_ratios.size(), offset, step_width, step_height, min.data<T>(), |
| 152 | + max_data, min_num, clip); |
| 153 | + |
| 154 | + framework::Tensor v; |
| 155 | + framework::TensorFromVector(variances, ctx.device_context(), &v); |
| 156 | + grid = (box_num * 4 + block - 1) / block; |
| 157 | + SetVariance<T><<<grid, block, 0, stream>>>(vars->data<T>(), v.data<T>(), |
| 158 | + variances.size(), box_num * 4); |
| 159 | + } |
| 160 | +}; // namespace operators |
| 161 | + |
| 162 | +} // namespace operators |
| 163 | +} // namespace paddle |
| 164 | + |
| 165 | +namespace ops = paddle::operators; |
| 166 | +REGISTER_OP_CUDA_KERNEL(prior_box, ops::PriorBoxOpCUDAKernel<float>, |
| 167 | + ops::PriorBoxOpCUDAKernel<double>); |
0 commit comments