From 009de8bfb0256d85fbaa0743347a5533b6fed929 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Mon, 21 May 2018 17:23:14 +0800 Subject: [PATCH 1/9] Add quad transform. --- paddle/fluid/operators/quad_transform_op.cc | 66 ++++++++++++++++++ paddle/fluid/operators/quad_transform_op.cu | 68 +++++++++++++++++++ paddle/fluid/operators/quad_transform_op.h | 56 +++++++++++++++ .../tests/unittests/test_quad_transform.py | 62 +++++++++++++++++ 4 files changed, 252 insertions(+) create mode 100644 paddle/fluid/operators/quad_transform_op.cc create mode 100644 paddle/fluid/operators/quad_transform_op.cu create mode 100644 paddle/fluid/operators/quad_transform_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_quad_transform.py diff --git a/paddle/fluid/operators/quad_transform_op.cc b/paddle/fluid/operators/quad_transform_op.cc new file mode 100644 index 00000000000000..115bfe27dd3035 --- /dev/null +++ b/paddle/fluid/operators/quad_transform_op.cc @@ -0,0 +1,66 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/quad_transform_op.h" + +namespace paddle { +namespace operators { + +class QuadTransformOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input (Input) of quad transform op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output (Output) of quad transform op should not be null."); + + auto in_dim = ctx->GetInputDim("Input"); + + PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4."); + PADDLE_ENFORCE_EQ(in_dim[1], 8, "input's second dimension must be 8"); + + ctx->SetOutputDim("Input", in_dim); + ctx->ShareLoD("Input", /*->*/ "Output"); + } +}; + +class QuadTransformOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Input", "The input with shape [batch_size, 8, height, width]"); + AddOutput("Output", "The output with the same shape as input"); + + AddComment(R"DOC( +QuadTransform Operator. +The input is the final geometry output in detection network. +We use 8 numbers to denote the coordinate shift from four corner vertices of +the quadrangle to the pixel location. As each distance offset contains two numbers (xi, yi), +the geometry output contains 8 channels. +QuadTransform Operator is used to transform the coordinate shift to the real coordinate. 
+)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(quad_transform, ops::QuadTransformOp, + ops::QuadTransformOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + quad_transform, ops::QuadTransformKernel, + ops::QuadTransformKernel); diff --git a/paddle/fluid/operators/quad_transform_op.cu b/paddle/fluid/operators/quad_transform_op.cu new file mode 100644 index 00000000000000..34a0a1e65a823d --- /dev/null +++ b/paddle/fluid/operators/quad_transform_op.cu @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/operators/quad_transform_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { +using platform::PADDLE_CUDA_NUM_THREADS; + +template +__global__ void QuadTransformKernel(const int n, const int h, const int w, + const T* input, T* output) { + int id_n = threadIdx.x + blockDim.x * blockIdx.x; + int id_h = threadIdx.y + blockDim.y * blockIdx.y; + int id_w = threadIdx.z + blockDim.z * blockIdx.z; + if (idx < n && idy < h && idz < w) { + int id = id_n * h * w + w * id_h + id_w; + if (id_n % 2 == 0) { + output[id] = input[id] + id_w; + } else { + output[id] = input[id] + id_h; + } + } +} + +template +class QuadTransfromOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto* in = ctx.Input("Input"); + auto in_dims = in->dims(); + const T* in_data = in->data(); + auto* out = ctx.Output("Output"); + T* out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = in_dims[0]; + int height = in_dims[2]; + int width = in_dims[3]; + dim3 threadsPerBlock(4, 16, 16); + dim3 numBlocks((batch_size * 8) / threadsPerBlock.x, + height / threadsPerBlock.y, width / threadsPerBlock.z); + QuadTransfromCudaKernel<<>>( + batch_size * 8, height, width, in_data, out_data); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(quad_transform, paddle::operators::OpCUDAKernel, + paddle::operators::AccuracyOpCUDAKernel); diff --git a/paddle/fluid/operators/quad_transform_op.h b/paddle/fluid/operators/quad_transform_op.h new file mode 100644 index 00000000000000..593f0062517a70 --- /dev/null +++ b/paddle/fluid/operators/quad_transform_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class QuadTransformKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto* in = ctx.Input("Input"); + auto in_dims = in->dims(); + const T* in_data = in->data(); + auto* out = ctx.Output("Output"); + T* out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = in_dims[0]; + int height = in_dims[2]; + int width = in_dims[3]; + int id = 0; + for (int id_n = 0; id_n < batch_size * 8; ++id_n) { + for (int id_h = 0; id_h < height; ++id_h) { + for (int id_w = 0; id_w < width; ++id_w) { + id = id_n * height * width + width * id_h + id_w; + if (id_n % 2 == 0) { + out_data[id] = in_data[id] + id_w; + } else { + out_data[id] = in_data[id] + id_h; + } + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_quad_transform.py b/python/paddle/fluid/tests/unittests/test_quad_transform.py new file mode 100644 index 00000000000000..bb84c43314b042 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_quad_transform.py @@ -0,0 +1,62 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest + + +def QuadTransform(input): + shape = input.shape + batch_size = shape[0] + h = shape[2] + w = shape[3] + h_indexes = np.array(range(h) * w).reshape( + [w, h]).transpose()[np.newaxis, :] # [1, h, w] + w_indexes = np.array(range(w) * h).reshape( + [h, w])[np.newaxis, :] # [1, h, w] + indexes = np.concatenate( + (h_indexes, w_indexes))[np.newaxis, :] # [1, 2, h, w] + indexes = indexes.repeat([4], axis=0)[np.newaxis, :] # [1, 4, 2, h, w] + indexes = indexes.repeat([batch_size], axis=0) # [batch_size, 4, 2, h, w] + return input + indexes.reshape(input.shape) # [batch_size, 8, h, w] + + +class TestQuadTransformOp(OpTest): + def config(self): + self.input_shape = (1, 8, 2, 2) + + def setUp(self): + self.op_type = "quad_transform" + input = np.random.random(self.input_shape).astype("float32") + self.inputs = {'Input': input} + output = QuadTransform(input) + self.outputs = {'Ouput': output} + + def test_check_output(self): + self.check_output() + + +class TestCase1(TestQuadTransformOp): + def config(self): + self.input_shape = (2, 8, 3, 2) + + +class TestCase2(TestQuadTransformOp): + def config(self): + self.input_shape = (3, 2, 4, 5) + + +if __name__ == '__main__': + unittest.main() From 35b9d7ca3968013d1e69ba46d301e9f3279b4688 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 22 May 2018 10:27:16 +0800 Subject: [PATCH 2/9] Fix some syntax error. --- paddle/fluid/operators/quad_transform_op.cc | 8 ++++---- paddle/fluid/operators/quad_transform_op.cu | 17 ++++++++++------- paddle/fluid/operators/quad_transform_op.h | 2 +- .../tests/unittests/test_quad_transform.py | 7 ++++--- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/operators/quad_transform_op.cc b/paddle/fluid/operators/quad_transform_op.cc index 115bfe27dd3035..658eb93acd0858 100644 --- a/paddle/fluid/operators/quad_transform_op.cc +++ b/paddle/fluid/operators/quad_transform_op.cc @@ -32,8 +32,7 @@ class QuadTransformOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4."); PADDLE_ENFORCE_EQ(in_dim[1], 8, "input's second dimension must be 8"); - ctx->SetOutputDim("Input", in_dim); - ctx->ShareLoD("Input", /*->*/ "Output"); + ctx->SetOutputDim("Output", in_dim); } }; @@ -62,5 +61,6 @@ REGISTER_OPERATOR(quad_transform, ops::QuadTransformOp, ops::QuadTransformOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - quad_transform, ops::QuadTransformKernel, - ops::QuadTransformKernel); + quad_transform, + ops::QuadTransformCPUKernel, + ops::QuadTransformCPUKernel); diff --git a/paddle/fluid/operators/quad_transform_op.cu b/paddle/fluid/operators/quad_transform_op.cu index 34a0a1e65a823d..261c3685e9d6c7 100644 --- a/paddle/fluid/operators/quad_transform_op.cu +++ b/paddle/fluid/operators/quad_transform_op.cu @@ -22,13 +22,13 @@ namespace paddle { namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; -template +template __global__ void QuadTransformKernel(const int n, const int h, const int w, const T* input, T* output) { int id_n = threadIdx.x + blockDim.x * blockIdx.x; int id_h = threadIdx.y + blockDim.y * blockIdx.y; int id_w = threadIdx.z + blockDim.z * blockIdx.z; - if (idx < n && idy < h && idz < w) { + if (id_n < n && id_h < h && id_w < w) { int id = id_n * h * w + w * id_h + id_w; if (id_n % 2 == 0) { output[id] = input[id] + id_w; @@ -53,10 +53,12 @@ class QuadTransfromOpCUDAKernel : public framework::OpKernel { int batch_size = in_dims[0]; int height = 
in_dims[2]; int width = in_dims[3]; - dim3 threadsPerBlock(4, 16, 16); + dim3 threadsPerBlock(2, 16, 16); dim3 numBlocks((batch_size * 8) / threadsPerBlock.x, - height / threadsPerBlock.y, width / threadsPerBlock.z); - QuadTransfromCudaKernel<<>>( + (height + threadsPerBlock.y - 1) / threadsPerBlock.y, + (width + threadsPerBlock.z - 1) / threadsPerBlock.z); + auto stream = ctx.cuda_device_context().stream(); + QuadTransformKernel<<>>( batch_size * 8, height, width, in_data, out_data); } }; @@ -64,5 +66,6 @@ class QuadTransfromOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(quad_transform, paddle::operators::OpCUDAKernel, - paddle::operators::AccuracyOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(quad_transform, + paddle::operators::QuadTransfromOpCUDAKernel, + paddle::operators::QuadTransfromOpCUDAKernel); diff --git a/paddle/fluid/operators/quad_transform_op.h b/paddle/fluid/operators/quad_transform_op.h index 593f0062517a70..7a44b4af64e0ad 100644 --- a/paddle/fluid/operators/quad_transform_op.h +++ b/paddle/fluid/operators/quad_transform_op.h @@ -22,7 +22,7 @@ namespace operators { using Tensor = framework::Tensor; template -class QuadTransformKernel : public framework::OpKernel { +class QuadTransformCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), diff --git a/python/paddle/fluid/tests/unittests/test_quad_transform.py b/python/paddle/fluid/tests/unittests/test_quad_transform.py index bb84c43314b042..f470f39a379186 100644 --- a/python/paddle/fluid/tests/unittests/test_quad_transform.py +++ b/python/paddle/fluid/tests/unittests/test_quad_transform.py @@ -27,7 +27,7 @@ def QuadTransform(input): w_indexes = np.array(range(w) * h).reshape( [h, w])[np.newaxis, :] # [1, h, w] indexes = np.concatenate( - (h_indexes, w_indexes))[np.newaxis, :] # [1, 2, h, w] + (w_indexes, h_indexes))[np.newaxis, :] # [1, 2, h, w] indexes = indexes.repeat([4], axis=0)[np.newaxis, :] # [1, 4, 2, h, w] indexes = indexes.repeat([batch_size], axis=0) # [batch_size, 4, 2, h, w] return input + indexes.reshape(input.shape) # [batch_size, 8, h, w] @@ -38,11 +38,12 @@ def config(self): self.input_shape = (1, 8, 2, 2) def setUp(self): + self.config() self.op_type = "quad_transform" input = np.random.random(self.input_shape).astype("float32") self.inputs = {'Input': input} output = QuadTransform(input) - self.outputs = {'Ouput': output} + self.outputs = {'Output': output} def test_check_output(self): self.check_output() @@ -55,7 +56,7 @@ def config(self): class TestCase2(TestQuadTransformOp): def config(self): - self.input_shape = (3, 2, 4, 5) + self.input_shape = (3, 8, 4, 5) if __name__ == '__main__': From efb46e59a0e08b94186615e9183e3a9f3425ac5a Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 22 May 2018 10:43:34 +0800 Subject: [PATCH 3/9] Fix CUDA kernel launch configure. --- paddle/fluid/operators/quad_transform_op.cu | 9 ++++++--- paddle/fluid/operators/quad_transform_op.h | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/quad_transform_op.cu b/paddle/fluid/operators/quad_transform_op.cu index 261c3685e9d6c7..def1c54e38584d 100644 --- a/paddle/fluid/operators/quad_transform_op.cu +++ b/paddle/fluid/operators/quad_transform_op.cu @@ -21,6 +21,7 @@ limitations under the License. 
*/ namespace paddle { namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; +#define CUDA_BLOCK_SIZE 16 template __global__ void QuadTransformKernel(const int n, const int h, const int w, @@ -53,13 +54,15 @@ class QuadTransfromOpCUDAKernel : public framework::OpKernel { int batch_size = in_dims[0]; int height = in_dims[2]; int width = in_dims[3]; - dim3 threadsPerBlock(2, 16, 16); - dim3 numBlocks((batch_size * 8) / threadsPerBlock.x, + dim3 threadsPerBlock( + PADDLE_CUDA_NUM_THREADS / (CUDA_BLOCK_SIZE * CUDA_BLOCK_SIZE), + CUDA_BLOCK_SIZE, CUDA_BLOCK_SIZE); + dim3 numBlocks((batch_size * GEO_CHANNEL) / threadsPerBlock.x, (height + threadsPerBlock.y - 1) / threadsPerBlock.y, (width + threadsPerBlock.z - 1) / threadsPerBlock.z); auto stream = ctx.cuda_device_context().stream(); QuadTransformKernel<<>>( - batch_size * 8, height, width, in_data, out_data); + batch_size * GEO_CHANNEL, height, width, in_data, out_data); } }; diff --git a/paddle/fluid/operators/quad_transform_op.h b/paddle/fluid/operators/quad_transform_op.h index 7a44b4af64e0ad..b0b17b594c7b19 100644 --- a/paddle/fluid/operators/quad_transform_op.h +++ b/paddle/fluid/operators/quad_transform_op.h @@ -19,6 +19,7 @@ limitations under the License. */ namespace paddle { namespace operators { +#define GEO_CHANNEL 8 using Tensor = framework::Tensor; template @@ -37,7 +38,7 @@ class QuadTransformCPUKernel : public framework::OpKernel { int height = in_dims[2]; int width = in_dims[3]; int id = 0; - for (int id_n = 0; id_n < batch_size * 8; ++id_n) { + for (int id_n = 0; id_n < batch_size * GEO_CHANNEL; ++id_n) { for (int id_h = 0; id_h < height; ++id_h) { for (int id_w = 0; id_w < width; ++id_w) { id = id_n * height * width + width * id_h + id_w; From a90bb1250eb19235711cdd9326e2c2a557fc6905 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 22 May 2018 17:50:41 +0800 Subject: [PATCH 4/9] Generalize geometry channels. 
--- .../fluid/operators/detection/CMakeLists.txt | 2 + .../operators/detection/polygon_restore.cc | 101 ++++++++++++++++++ .../polygon_restore.cu} | 22 ++-- paddle/fluid/operators/quad_transform_op.cc | 66 ------------ paddle/fluid/operators/quad_transform_op.h | 57 ---------- .../tests/unittests/test_quad_transform.py | 63 ----------- 6 files changed, 113 insertions(+), 198 deletions(-) create mode 100644 paddle/fluid/operators/detection/polygon_restore.cc rename paddle/fluid/operators/{quad_transform_op.cu => detection/polygon_restore.cu} (74%) delete mode 100644 paddle/fluid/operators/quad_transform_op.cc delete mode 100644 paddle/fluid/operators/quad_transform_op.h delete mode 100644 python/paddle/fluid/tests/unittests/test_quad_transform.py diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index a5bb58c2f4047a..11a1a651822275 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -24,6 +24,8 @@ detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc target_assign_op.cu) +detection_library(polygon_restore_op SRCS polygon_restore_op.cc + polygon_restore_op.cu) # Export local libraries to parent set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/detection/polygon_restore.cc b/paddle/fluid/operators/detection/polygon_restore.cc new file mode 100644 index 00000000000000..ace086e2775877 --- /dev/null +++ b/paddle/fluid/operators/detection/polygon_restore.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class PolygonRestoreCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto* in = ctx.Input("Input"); + auto in_dims = in->dims(); + const T* in_data = in->data(); + auto* out = ctx.Output("Output"); + T* out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = in_dims[0]; + int geo_channel = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int id = 0; + for (int id_n = 0; id_n < batch_size * geo_channel; ++id_n) { + for (int id_h = 0; id_h < height; ++id_h) { + for (int id_w = 0; id_w < width; ++id_w) { + id = id_n * height * width + width * id_h + id_w; + if (id_n % 2 == 0) { + out_data[id] = in_data[id] + id_w; + } else { + out_data[id] = in_data[id] + id_h; + } + } + } + } + } +}; + +class PolygonRestoreOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input (Input) of polygon restore op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output (Output) of polygon restore op should not be null."); + + auto in_dim = ctx->GetInputDim("Input"); + + PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4."); + PADDLE_ENFORCE_EQ(in_dim[1] % 2, 0, + "input's second dimension must be even."); + + ctx->SetOutputDim("Output", in_dim); + } +}; + +class PolygonRestoreOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "The input with shape [batch_size, geometry_channels, height, width]"); + AddOutput("Output", "The output with the same shape as input"); + + AddComment(R"DOC( +PolygonRestore Operator. +The input is the final geometry output in detection network. +We use 2*n numbers to denote the coordinate shift from n corner vertices of +the polygon to the pixel location. As each distance offset contains two numbers (xi, yi), +the geometry output contains 2*n channels. +PolygonRestore Operator is used to transform the coordinate shift to the real coordinate. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(polygon_restore, ops::PolygonRestoreOp, + ops::PolygonRestoreOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + polygon_restore, + ops::PolygonRestoreCPUKernel, + ops::PolygonRestoreCPUKernel); diff --git a/paddle/fluid/operators/quad_transform_op.cu b/paddle/fluid/operators/detection/polygon_restore.cu similarity index 74% rename from paddle/fluid/operators/quad_transform_op.cu rename to paddle/fluid/operators/detection/polygon_restore.cu index def1c54e38584d..2462f7c92b38de 100644 --- a/paddle/fluid/operators/quad_transform_op.cu +++ b/paddle/fluid/operators/detection/polygon_restore.cu @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include -#include -#include "paddle/fluid/operators/quad_transform_op.h" #include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/gpu_info.h" @@ -24,8 +21,8 @@ using platform::PADDLE_CUDA_NUM_THREADS; #define CUDA_BLOCK_SIZE 16 template -__global__ void QuadTransformKernel(const int n, const int h, const int w, - const T* input, T* output) { +__global__ void PolygonRestoreKernel(const int n, const int h, const int w, + const T* input, T* output) { int id_n = threadIdx.x + blockDim.x * blockIdx.x; int id_h = threadIdx.y + blockDim.y * blockIdx.y; int id_w = threadIdx.z + blockDim.z * blockIdx.z; @@ -40,7 +37,7 @@ __global__ void QuadTransformKernel(const int n, const int h, const int w, } template -class QuadTransfromOpCUDAKernel : public framework::OpKernel { +class PolygonRestoreOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -52,23 +49,24 @@ class QuadTransfromOpCUDAKernel : public framework::OpKernel { T* out_data = out->mutable_data(ctx.GetPlace()); int batch_size = in_dims[0]; + int geo_channels = in_dims[1]; int height = in_dims[2]; int width = in_dims[3]; dim3 threadsPerBlock( PADDLE_CUDA_NUM_THREADS / (CUDA_BLOCK_SIZE * CUDA_BLOCK_SIZE), CUDA_BLOCK_SIZE, CUDA_BLOCK_SIZE); - dim3 numBlocks((batch_size * GEO_CHANNEL) / threadsPerBlock.x, + dim3 numBlocks((batch_size * geo_channels) / threadsPerBlock.x, (height + threadsPerBlock.y - 1) / threadsPerBlock.y, (width + threadsPerBlock.z - 1) / threadsPerBlock.z); auto stream = ctx.cuda_device_context().stream(); - QuadTransformKernel<<>>( - batch_size * GEO_CHANNEL, height, width, in_data, out_data); + PolygonRestoreKernel<<>>( + batch_size * geo_channels, height, width, in_data, out_data); } }; } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(quad_transform, - paddle::operators::QuadTransfromOpCUDAKernel, - paddle::operators::QuadTransfromOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(polygon_restore, + paddle::operators::PolygonRestoreOpCUDAKernel, + paddle::operators::PolygonRestoreOpCUDAKernel); diff --git a/paddle/fluid/operators/quad_transform_op.cc b/paddle/fluid/operators/quad_transform_op.cc deleted file mode 100644 index 658eb93acd0858..00000000000000 --- a/paddle/fluid/operators/quad_transform_op.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/quad_transform_op.h" - -namespace paddle { -namespace operators { - -class QuadTransformOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input (Input) of quad transform op should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output (Output) of quad transform op should not be null."); - - auto in_dim = ctx->GetInputDim("Input"); - - PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4."); - PADDLE_ENFORCE_EQ(in_dim[1], 8, "input's second dimension must be 8"); - - ctx->SetOutputDim("Output", in_dim); - } -}; - -class QuadTransformOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("Input", "The input with shape [batch_size, 8, height, width]"); - AddOutput("Output", "The output with the same shape as input"); - - AddComment(R"DOC( -QuadTransform Operator. -The input is the final geometry output in detection network. -We use 8 numbers to denote the coordinate shift from four corner vertices of -the quadrangle to the pixel location. As each distance offset contains two numbers (xi, yi), -the geometry output contains 8 channels. -QuadTransform Operator is used to transform the coordinate shift to the real coordinate. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(quad_transform, ops::QuadTransformOp, - ops::QuadTransformOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - quad_transform, - ops::QuadTransformCPUKernel, - ops::QuadTransformCPUKernel); diff --git a/paddle/fluid/operators/quad_transform_op.h b/paddle/fluid/operators/quad_transform_op.h deleted file mode 100644 index b0b17b594c7b19..00000000000000 --- a/paddle/fluid/operators/quad_transform_op.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -#define GEO_CHANNEL 8 -using Tensor = framework::Tensor; - -template -class QuadTransformCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto* in = ctx.Input("Input"); - auto in_dims = in->dims(); - const T* in_data = in->data(); - auto* out = ctx.Output("Output"); - T* out_data = out->mutable_data(ctx.GetPlace()); - - int batch_size = in_dims[0]; - int height = in_dims[2]; - int width = in_dims[3]; - int id = 0; - for (int id_n = 0; id_n < batch_size * GEO_CHANNEL; ++id_n) { - for (int id_h = 0; id_h < height; ++id_h) { - for (int id_w = 0; id_w < width; ++id_w) { - id = id_n * height * width + width * id_h + id_w; - if (id_n % 2 == 0) { - out_data[id] = in_data[id] + id_w; - } else { - out_data[id] = in_data[id] + id_h; - } - } - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_quad_transform.py b/python/paddle/fluid/tests/unittests/test_quad_transform.py deleted file mode 100644 index f470f39a379186..00000000000000 --- a/python/paddle/fluid/tests/unittests/test_quad_transform.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import numpy as np -from op_test import OpTest - - -def QuadTransform(input): - shape = input.shape - batch_size = shape[0] - h = shape[2] - w = shape[3] - h_indexes = np.array(range(h) * w).reshape( - [w, h]).transpose()[np.newaxis, :] # [1, h, w] - w_indexes = np.array(range(w) * h).reshape( - [h, w])[np.newaxis, :] # [1, h, w] - indexes = np.concatenate( - (w_indexes, h_indexes))[np.newaxis, :] # [1, 2, h, w] - indexes = indexes.repeat([4], axis=0)[np.newaxis, :] # [1, 4, 2, h, w] - indexes = indexes.repeat([batch_size], axis=0) # [batch_size, 4, 2, h, w] - return input + indexes.reshape(input.shape) # [batch_size, 8, h, w] - - -class TestQuadTransformOp(OpTest): - def config(self): - self.input_shape = (1, 8, 2, 2) - - def setUp(self): - self.config() - self.op_type = "quad_transform" - input = np.random.random(self.input_shape).astype("float32") - self.inputs = {'Input': input} - output = QuadTransform(input) - self.outputs = {'Output': output} - - def test_check_output(self): - self.check_output() - - -class TestCase1(TestQuadTransformOp): - def config(self): - self.input_shape = (2, 8, 3, 2) - - -class TestCase2(TestQuadTransformOp): - def config(self): - self.input_shape = (3, 8, 4, 5) - - -if __name__ == '__main__': - unittest.main() From 20fa67656045aff6914465aea4c64601c045d270 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 22 May 2018 19:02:07 +0800 Subject: [PATCH 5/9] Rename QuadTransform to PolygonRestore. 
--- .../operators/detection/polygon_restore.cc | 101 ------------------ .../operators/detection/polygon_restore.cu | 72 ------------- .../paddle/fluid/tests/unittests/op_test.py | 13 +-- 3 files changed, 7 insertions(+), 179 deletions(-) delete mode 100644 paddle/fluid/operators/detection/polygon_restore.cc delete mode 100644 paddle/fluid/operators/detection/polygon_restore.cu diff --git a/paddle/fluid/operators/detection/polygon_restore.cc b/paddle/fluid/operators/detection/polygon_restore.cc deleted file mode 100644 index ace086e2775877..00000000000000 --- a/paddle/fluid/operators/detection/polygon_restore.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class PolygonRestoreCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto* in = ctx.Input("Input"); - auto in_dims = in->dims(); - const T* in_data = in->data(); - auto* out = ctx.Output("Output"); - T* out_data = out->mutable_data(ctx.GetPlace()); - - int batch_size = in_dims[0]; - int geo_channel = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - int id = 0; - for (int id_n = 0; id_n < batch_size * geo_channel; ++id_n) { - for (int id_h = 0; id_h < height; ++id_h) { - for (int id_w = 0; id_w < width; ++id_w) { - id = id_n * height * width + width * id_h + id_w; - if (id_n % 2 == 0) { - out_data[id] = in_data[id] + id_w; - } else { - out_data[id] = in_data[id] + id_h; - } - } - } - } - } -}; - -class PolygonRestoreOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input (Input) of polygon restore op should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output (Output) of polygon restore op should not be null."); - - auto in_dim = ctx->GetInputDim("Input"); - - PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4."); - PADDLE_ENFORCE_EQ(in_dim[1] % 2, 0, - "input's second dimension must be even."); - - ctx->SetOutputDim("Output", in_dim); - } -}; - -class PolygonRestoreOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput( - "Input", - "The input with shape [batch_size, geometry_channels, height, width]"); - AddOutput("Output", "The output with the same shape as input"); - - AddComment(R"DOC( -PolygonRestore Operator. -The input is the final geometry output in detection network. -We use 2*n numbers to denote the coordinate shift from n corner vertices of -the polygon to the pixel location. As each distance offset contains two numbers (xi, yi), -the geometry output contains 2*n channels. 
-PolygonRestore Operator is used to transform the coordinate shift to the real coordinate. -)DOC"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OPERATOR(polygon_restore, ops::PolygonRestoreOp, - ops::PolygonRestoreOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL( - polygon_restore, - ops::PolygonRestoreCPUKernel, - ops::PolygonRestoreCPUKernel); diff --git a/paddle/fluid/operators/detection/polygon_restore.cu b/paddle/fluid/operators/detection/polygon_restore.cu deleted file mode 100644 index 2462f7c92b38de..00000000000000 --- a/paddle/fluid/operators/detection/polygon_restore.cu +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/platform/cuda_primitives.h" -#include "paddle/fluid/platform/gpu_info.h" - -namespace paddle { -namespace operators { -using platform::PADDLE_CUDA_NUM_THREADS; -#define CUDA_BLOCK_SIZE 16 - -template -__global__ void PolygonRestoreKernel(const int n, const int h, const int w, - const T* input, T* output) { - int id_n = threadIdx.x + blockDim.x * blockIdx.x; - int id_h = threadIdx.y + blockDim.y * blockIdx.y; - int id_w = threadIdx.z + blockDim.z * blockIdx.z; - if (id_n < n && id_h < h && id_w < w) { - int id = id_n * h * w + w * id_h + id_w; - if (id_n % 2 == 0) { - output[id] = input[id] + id_w; - } else { - output[id] = input[id] + id_h; - } - } -} - -template -class PolygonRestoreOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); - auto* in = ctx.Input("Input"); - auto in_dims = in->dims(); - const T* in_data = in->data(); - auto* out = ctx.Output("Output"); - T* out_data = out->mutable_data(ctx.GetPlace()); - - int batch_size = in_dims[0]; - int geo_channels = in_dims[1]; - int height = in_dims[2]; - int width = in_dims[3]; - dim3 threadsPerBlock( - PADDLE_CUDA_NUM_THREADS / (CUDA_BLOCK_SIZE * CUDA_BLOCK_SIZE), - CUDA_BLOCK_SIZE, CUDA_BLOCK_SIZE); - dim3 numBlocks((batch_size * geo_channels) / threadsPerBlock.x, - (height + threadsPerBlock.y - 1) / threadsPerBlock.y, - (width + threadsPerBlock.z - 1) / threadsPerBlock.z); - auto stream = ctx.cuda_device_context().stream(); - PolygonRestoreKernel<<>>( - batch_size * geo_channels, height, width, in_data, out_data); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL(polygon_restore, - paddle::operators::PolygonRestoreOpCUDAKernel, - paddle::operators::PolygonRestoreOpCUDAKernel); diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 299ab8e51f017e..d41e253ed82bfa 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -342,8 +342,9 @@ def find_actual(target_name, fetch_list): def check_output(self, atol=1e-5): places = 
[core.CPUPlace()] - if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): - places.append(core.CUDAPlace(0)) + # if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): + + # places.append(core.CUDAPlace(0)) for place in places: self.check_output_with_place(place, atol) @@ -473,9 +474,9 @@ def _numpy_to_lod_tensor(np_value, lod, place): def np_dtype_to_fluid_dtype(input): """Change the dtype of float16 numpy array - numpy float16 is binded to paddle::platform::float16 + numpy float16 is binded to paddle::platform::float16 in tensor_py.h via the help of uint16 data type since - the internal memory representation of float16 is + the internal memory representation of float16 is uint16_t in paddle and np.uint16 in numpy, which are themselves binded together by pybind. @@ -483,9 +484,9 @@ def np_dtype_to_fluid_dtype(input): input: input numpy array Returns: - input: The dtype of input will be changed to np.uint16 if + input: The dtype of input will be changed to np.uint16 if it is originally np.float16, such that the internal memory - of input will be reinterpreted as of dtype np.uint16. + of input will be reinterpreted as of dtype np.uint16. """ if input.dtype == np.float16: input.dtype = np.uint16 From a25ba2479e0d69a6b8948e4855a3d16b6a1482a4 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Thu, 24 May 2018 10:59:45 +0800 Subject: [PATCH 6/9] Rename op. --- .../operators/detection/polygon_restore_op.cc | 103 ++++++++++++++++++ .../operators/detection/polygon_restore_op.cu | 75 +++++++++++++ .../tests/unittests/test_polygon_restore.py | 68 ++++++++++++ 3 files changed, 246 insertions(+) create mode 100644 paddle/fluid/operators/detection/polygon_restore_op.cc create mode 100644 paddle/fluid/operators/detection/polygon_restore_op.cu create mode 100644 python/paddle/fluid/tests/unittests/test_polygon_restore.py diff --git a/paddle/fluid/operators/detection/polygon_restore_op.cc b/paddle/fluid/operators/detection/polygon_restore_op.cc new file mode 100644 index 00000000000000..9323410d563ba4 --- /dev/null +++ b/paddle/fluid/operators/detection/polygon_restore_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class PolygonRestoreCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto* in = ctx.Input("Input"); + auto in_dims = in->dims(); + const T* in_data = in->data(); + auto* out = ctx.Output("Output"); + T* out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = in_dims[0]; + int geo_channel = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int id = 0; + for (int id_n = 0; id_n < batch_size * geo_channel; ++id_n) { + for (int id_h = 0; id_h < height; ++id_h) { + for (int id_w = 0; id_w < width; ++id_w) { + id = id_n * height * width + width * id_h + id_w; + if (id_n % 2 == 0) { + out_data[id] = in_data[id] + id_w; + } else { + out_data[id] = in_data[id] + id_h; + } + } + } + } + } +}; + +class PolygonRestoreOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input (Input) of polygon restore op should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output (Output) of polygon restore op should not be null."); + + auto in_dim = ctx->GetInputDim("Input"); + + PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4."); + PADDLE_ENFORCE_EQ(in_dim[1] % 2, 0, + "input's second dimension must be even."); + + ctx->SetOutputDim("Output", in_dim); + } +}; + +class PolygonRestoreOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "The input with shape [batch_size, geometry_channels, height, width]"); + AddOutput("Output", "The output with the same shape as input"); + + AddComment(R"DOC( +PolygonRestore Operator. +The input is the final geometry output in detection network. +We use 2*n numbers to denote the coordinate shift from n corner vertices of +the polygon to the pixel location. As each distance offset contains two numbers (xi, yi), +the geometry output contains 2*n channels. +PolygonRestore Operator is used to transform the coordinate shift to the real coordinate. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(polygon_restore, ops::PolygonRestoreOp, + ops::PolygonRestoreOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + polygon_restore, + ops::PolygonRestoreCPUKernel, + ops::PolygonRestoreCPUKernel); diff --git a/paddle/fluid/operators/detection/polygon_restore_op.cu b/paddle/fluid/operators/detection/polygon_restore_op.cu new file mode 100644 index 00000000000000..4d74f59efcfac4 --- /dev/null +++ b/paddle/fluid/operators/detection/polygon_restore_op.cu @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_info.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using platform::PADDLE_CUDA_NUM_THREADS; +#define CUDA_BLOCK_SIZE 16 + +template +__global__ void PolygonRestoreKernel(const int n, const int h, const int w, + const T* input, T* output) { + int id_n = threadIdx.x + blockDim.x * blockIdx.x; + int id_h = threadIdx.y + blockDim.y * blockIdx.y; + int id_w = threadIdx.z + blockDim.z * blockIdx.z; + if (id_n < n && id_h < h && id_w < w) { + int id = id_n * h * w + w * id_h + id_w; + if (id_n % 2 == 0) { + output[id] = input[id] + id_w; + } else { + output[id] = input[id] + id_h; + } + } +} + +template +class PolygonRestoreOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + auto* in = ctx.Input("Input"); + auto in_dims = in->dims(); + const T* in_data = in->data(); + auto* out = ctx.Output("Output"); + T* out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = in_dims[0]; + int geo_channels = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + dim3 threadsPerBlock( + PADDLE_CUDA_NUM_THREADS / (CUDA_BLOCK_SIZE * CUDA_BLOCK_SIZE), + CUDA_BLOCK_SIZE, CUDA_BLOCK_SIZE); + dim3 numBlocks((batch_size * geo_channels) / threadsPerBlock.x, + (height + threadsPerBlock.y - 1) / threadsPerBlock.y, + (width + threadsPerBlock.z - 1) / threadsPerBlock.z); + auto stream = ctx.cuda_device_context().stream(); + PolygonRestoreKernel<<>>( + batch_size * geo_channels, height, width, in_data, out_data); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(polygon_restore, + paddle::operators::PolygonRestoreOpCUDAKernel, + paddle::operators::PolygonRestoreOpCUDAKernel); diff --git a/python/paddle/fluid/tests/unittests/test_polygon_restore.py b/python/paddle/fluid/tests/unittests/test_polygon_restore.py new file mode 100644 index 00000000000000..afc88db6485fd5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_polygon_restore.py @@ -0,0 +1,68 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +from op_test import OpTest + + +def PolygonRestore(input): + shape = input.shape + batch_size = shape[0] + geo_channels = shape[1] + h = shape[2] + w = shape[3] + h_indexes = np.array(range(h) * w).reshape( + [w, h]).transpose()[np.newaxis, :] # [1, h, w] + w_indexes = np.array(range(w) * h).reshape( + [h, w])[np.newaxis, :] # [1, h, w] + indexes = np.concatenate( + (w_indexes, h_indexes))[np.newaxis, :] # [1, 2, h, w] + indexes = indexes.repeat( + [geo_channels / 2], + axis=0)[np.newaxis, :] # [1, geo_channels/2, 2, h, w] + indexes = indexes.repeat( + [batch_size], axis=0) # [batch_size, geo_channels/2, 2, h, w] + return input + indexes.reshape( + input.shape) # [batch_size, geo_channels, h, w] + + +class TestPolygonRestoreOp(OpTest): + def config(self): + self.input_shape = (1, 8, 2, 2) + + def setUp(self): + self.config() + self.op_type = "polygon_restore" + input = np.random.random(self.input_shape).astype("float32") + self.inputs = {'Input': input} + output = PolygonRestore(input) + self.outputs = {'Output': output} + + def test_check_output(self): + self.check_output() + + +class TestCase1(TestPolygonRestoreOp): + def config(self): + self.input_shape = (2, 10, 3, 2) + + +class TestCase2(TestPolygonRestoreOp): + def config(self): + self.input_shape = (3, 12, 4, 5) + + +if __name__ == '__main__': + unittest.main() From 38b1a819855e4d217a353ee7e623daaec388bc05 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 25 May 2018 10:22:57 +0000 Subject: [PATCH 7/9] Rename op and fix computation. --- .../detection/{polygon_restore_op.cc => box_restore_op.cc} | 4 ++-- .../detection/{polygon_restore_op.cu => box_restore_op.cu} | 4 ++-- .../{test_polygon_restore.py => test_box_restore.py} | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) rename paddle/fluid/operators/detection/{polygon_restore_op.cc => box_restore_op.cc} (97%) rename paddle/fluid/operators/detection/{polygon_restore_op.cu => box_restore_op.cu} (97%) rename python/paddle/fluid/tests/unittests/{test_polygon_restore.py => test_box_restore.py} (95%) diff --git a/paddle/fluid/operators/detection/polygon_restore_op.cc b/paddle/fluid/operators/detection/box_restore_op.cc similarity index 97% rename from paddle/fluid/operators/detection/polygon_restore_op.cc rename to paddle/fluid/operators/detection/box_restore_op.cc index 9323410d563ba4..842bd6c70c9dc2 100644 --- a/paddle/fluid/operators/detection/polygon_restore_op.cc +++ b/paddle/fluid/operators/detection/box_restore_op.cc @@ -41,9 +41,9 @@ class PolygonRestoreCPUKernel : public framework::OpKernel { for (int id_w = 0; id_w < width; ++id_w) { id = id_n * height * width + width * id_h + id_w; if (id_n % 2 == 0) { - out_data[id] = in_data[id] + id_w; + out_data[id] = id_w - in_data[id]; } else { - out_data[id] = in_data[id] + id_h; + out_data[id] = id_h - in_data[id]; } } } diff --git a/paddle/fluid/operators/detection/polygon_restore_op.cu b/paddle/fluid/operators/detection/box_restore_op.cu similarity index 97% rename from paddle/fluid/operators/detection/polygon_restore_op.cu rename to paddle/fluid/operators/detection/box_restore_op.cu index 4d74f59efcfac4..80fb5dfda9eb12 100644 --- a/paddle/fluid/operators/detection/polygon_restore_op.cu +++ b/paddle/fluid/operators/detection/box_restore_op.cu @@ -32,9 +32,9 @@ __global__ void PolygonRestoreKernel(const int n, const int h, const int w, if (id_n < n && id_h < h && id_w < w) { int id = id_n * h * w + w * id_h + id_w; if (id_n % 2 == 0) { - output[id] = input[id] + id_w; + output[id] 
= id_w - input[id]; } else { - output[id] = input[id] + id_h; + output[id] = id_h - input[id]; } } } diff --git a/python/paddle/fluid/tests/unittests/test_polygon_restore.py b/python/paddle/fluid/tests/unittests/test_box_restore.py similarity index 95% rename from python/paddle/fluid/tests/unittests/test_polygon_restore.py rename to python/paddle/fluid/tests/unittests/test_box_restore.py index afc88db6485fd5..333bb961778fd5 100644 --- a/python/paddle/fluid/tests/unittests/test_polygon_restore.py +++ b/python/paddle/fluid/tests/unittests/test_box_restore.py @@ -34,8 +34,8 @@ def PolygonRestore(input): axis=0)[np.newaxis, :] # [1, geo_channels/2, 2, h, w] indexes = indexes.repeat( [batch_size], axis=0) # [batch_size, geo_channels/2, 2, h, w] - return input + indexes.reshape( - input.shape) # [batch_size, geo_channels, h, w] + return indexes.reshape( + input.shape) - input # [batch_size, geo_channels, h, w] class TestPolygonRestoreOp(OpTest): From f2668f1fc5381f922fc2093d821b7b87cf640f64 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 25 May 2018 11:11:39 +0000 Subject: [PATCH 8/9] Modify CMakeLists.txt for box_restore op. --- .../fluid/operators/detection/CMakeLists.txt | 4 ++-- .../operators/detection/box_restore_op.cc | 24 +++++++++---------- .../operators/detection/box_restore_op.cu | 14 +++++------ .../fluid/tests/unittests/test_box_restore.py | 12 +++++----- 4 files changed, 26 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 11a1a651822275..ced267a8ae9fbc 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -24,8 +24,8 @@ detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc target_assign_op.cu) -detection_library(polygon_restore_op SRCS polygon_restore_op.cc - polygon_restore_op.cu) +detection_library(box_restore_op SRCS box_restore_op.cc + box_restore_op.cu) # Export local libraries to parent set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/detection/box_restore_op.cc b/paddle/fluid/operators/detection/box_restore_op.cc index 842bd6c70c9dc2..0ac049529ba617 100644 --- a/paddle/fluid/operators/detection/box_restore_op.cc +++ b/paddle/fluid/operators/detection/box_restore_op.cc @@ -20,7 +20,7 @@ namespace operators { using Tensor = framework::Tensor; template -class PolygonRestoreCPUKernel : public framework::OpKernel { +class BoxRestoreCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), @@ -51,15 +51,15 @@ class PolygonRestoreCPUKernel : public framework::OpKernel { } }; -class PolygonRestoreOp : public framework::OperatorWithKernel { +class BoxRestoreOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input (Input) of polygon restore op should not be null."); + "Input (Input) of box restore op should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output (Output) of polygon restore op should not be null."); + "Output (Output) of box restore op should not be null."); auto in_dim = ctx->GetInputDim("Input"); @@ -71,7 +71,7 @@ class 
PolygonRestoreOp : public framework::OperatorWithKernel { } }; -class PolygonRestoreOpMaker : public framework::OpProtoAndCheckerMaker { +class BoxRestoreOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput( @@ -80,12 +80,12 @@ class PolygonRestoreOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Output", "The output with the same shape as input"); AddComment(R"DOC( -PolygonRestore Operator. +BoxRestore Operator. The input is the final geometry output in detection network. We use 2*n numbers to denote the coordinate shift from n corner vertices of -the polygon to the pixel location. As each distance offset contains two numbers (xi, yi), +the box to the pixel location. As each distance offset contains two numbers (xi, yi), the geometry output contains 2*n channels. -PolygonRestore Operator is used to transform the coordinate shift to the real coordinate. +BoxRestore Operator is used to transform the coordinate shift to the real coordinate. )DOC"); } }; @@ -94,10 +94,8 @@ PolygonRestore Operator is used to transform the coordinate shift to the real co } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(polygon_restore, ops::PolygonRestoreOp, - ops::PolygonRestoreOpMaker, +REGISTER_OPERATOR(box_restore, ops::BoxRestoreOp, ops::BoxRestoreOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - polygon_restore, - ops::PolygonRestoreCPUKernel, - ops::PolygonRestoreCPUKernel); + box_restore, ops::BoxRestoreCPUKernel, + ops::BoxRestoreCPUKernel); diff --git a/paddle/fluid/operators/detection/box_restore_op.cu b/paddle/fluid/operators/detection/box_restore_op.cu index 80fb5dfda9eb12..c8ff262fb7f72b 100644 --- a/paddle/fluid/operators/detection/box_restore_op.cu +++ b/paddle/fluid/operators/detection/box_restore_op.cu @@ -24,8 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS; #define CUDA_BLOCK_SIZE 16 template -__global__ void PolygonRestoreKernel(const int n, const int h, const int w, - const T* input, T* output) { +__global__ void BoxRestoreKernel(const int n, const int h, const int w, + const T* input, T* output) { int id_n = threadIdx.x + blockDim.x * blockIdx.x; int id_h = threadIdx.y + blockDim.y * blockIdx.y; int id_w = threadIdx.z + blockDim.z * blockIdx.z; @@ -40,7 +40,7 @@ __global__ void PolygonRestoreKernel(const int n, const int h, const int w, } template -class PolygonRestoreOpCUDAKernel : public framework::OpKernel { +class BoxRestoreOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -62,7 +62,7 @@ class PolygonRestoreOpCUDAKernel : public framework::OpKernel { (height + threadsPerBlock.y - 1) / threadsPerBlock.y, (width + threadsPerBlock.z - 1) / threadsPerBlock.z); auto stream = ctx.cuda_device_context().stream(); - PolygonRestoreKernel<<>>( + BoxRestoreKernel<<>>( batch_size * geo_channels, height, width, in_data, out_data); } }; @@ -70,6 +70,6 @@ class PolygonRestoreOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(polygon_restore, - paddle::operators::PolygonRestoreOpCUDAKernel, - paddle::operators::PolygonRestoreOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(box_restore, + paddle::operators::BoxRestoreOpCUDAKernel, + paddle::operators::BoxRestoreOpCUDAKernel); diff --git a/python/paddle/fluid/tests/unittests/test_box_restore.py b/python/paddle/fluid/tests/unittests/test_box_restore.py index 
333bb961778fd5..96ea9bd9b82c45 100644 --- a/python/paddle/fluid/tests/unittests/test_box_restore.py +++ b/python/paddle/fluid/tests/unittests/test_box_restore.py @@ -17,7 +17,7 @@ from op_test import OpTest -def PolygonRestore(input): +def BoxRestore(input): shape = input.shape batch_size = shape[0] geo_channels = shape[1] @@ -38,28 +38,28 @@ def PolygonRestore(input): input.shape) - input # [batch_size, geo_channels, h, w] -class TestPolygonRestoreOp(OpTest): +class TestBoxRestoreOp(OpTest): def config(self): self.input_shape = (1, 8, 2, 2) def setUp(self): self.config() - self.op_type = "polygon_restore" + self.op_type = "box_restore" input = np.random.random(self.input_shape).astype("float32") self.inputs = {'Input': input} - output = PolygonRestore(input) + output = BoxRestore(input) self.outputs = {'Output': output} def test_check_output(self): self.check_output() -class TestCase1(TestPolygonRestoreOp): +class TestCase1(TestBoxRestoreOp): def config(self): self.input_shape = (2, 10, 3, 2) -class TestCase2(TestPolygonRestoreOp): +class TestCase2(TestBoxRestoreOp): def config(self): self.input_shape = (3, 12, 4, 5) From a35d457b832897ff60508401daa61b5af2bc98e9 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 25 May 2018 15:02:45 +0000 Subject: [PATCH 9/9] Refine code: 1. rename op 2. uncomment unitest on GPU --- .../fluid/operators/detection/CMakeLists.txt | 4 +-- ...tore_op.cc => polygon_box_transform_op.cc} | 30 +++++++++++-------- ...tore_op.cu => polygon_box_transform_op.cu} | 15 +++++----- .../paddle/fluid/tests/unittests/op_test.py | 5 ++-- ...store.py => test_polygon_box_transform.py} | 12 ++++---- 5 files changed, 35 insertions(+), 31 deletions(-) rename paddle/fluid/operators/detection/{box_restore_op.cc => polygon_box_transform_op.cc} (73%) rename paddle/fluid/operators/detection/{box_restore_op.cu => polygon_box_transform_op.cu} (83%) rename python/paddle/fluid/tests/unittests/{test_box_restore.py => test_polygon_box_transform.py} (88%) diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index ced267a8ae9fbc..20d960f9fee1ea 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -24,8 +24,8 @@ detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc) detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu) detection_library(target_assign_op SRCS target_assign_op.cc target_assign_op.cu) -detection_library(box_restore_op SRCS box_restore_op.cc - box_restore_op.cu) +detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc + polygon_box_transform_op.cu) # Export local libraries to parent set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/detection/box_restore_op.cc b/paddle/fluid/operators/detection/polygon_box_transform_op.cc similarity index 73% rename from paddle/fluid/operators/detection/box_restore_op.cc rename to paddle/fluid/operators/detection/polygon_box_transform_op.cc index 0ac049529ba617..335e8dd470f851 100644 --- a/paddle/fluid/operators/detection/box_restore_op.cc +++ b/paddle/fluid/operators/detection/polygon_box_transform_op.cc @@ -20,7 +20,7 @@ namespace operators { using Tensor = framework::Tensor; template -class BoxRestoreCPUKernel : public framework::OpKernel { +class PolygonBoxTransformCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { 
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), @@ -51,15 +51,17 @@ class BoxRestoreCPUKernel : public framework::OpKernel { } }; -class BoxRestoreOp : public framework::OperatorWithKernel { +class PolygonBoxTransformOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input (Input) of box restore op should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output (Output) of box restore op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("Input"), + "Input (Input) of polygon_box transform op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Output"), + "Output (Output) of polygon_box transform op should not be null."); auto in_dim = ctx->GetInputDim("Input"); @@ -71,7 +73,7 @@ class BoxRestoreOp : public framework::OperatorWithKernel { } }; -class BoxRestoreOpMaker : public framework::OpProtoAndCheckerMaker { +class PolygonBoxTransformOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput( @@ -80,12 +82,12 @@ class BoxRestoreOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Output", "The output with the same shape as input"); AddComment(R"DOC( -BoxRestore Operator. +PolygonBoxTransform Operator. The input is the final geometry output in detection network. We use 2*n numbers to denote the coordinate shift from n corner vertices of -the box to the pixel location. As each distance offset contains two numbers (xi, yi), +the polygon_box to the pixel location. As each distance offset contains two numbers (xi, yi), the geometry output contains 2*n channels. -BoxRestore Operator is used to transform the coordinate shift to the real coordinate. +PolygonBoxTransform Operator is used to transform the coordinate shift to the real coordinate. 
)DOC"); } }; @@ -94,8 +96,10 @@ BoxRestore Operator is used to transform the coordinate shift to the real coordi } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(box_restore, ops::BoxRestoreOp, ops::BoxRestoreOpMaker, +REGISTER_OPERATOR(polygon_box_transform, ops::PolygonBoxTransformOp, + ops::PolygonBoxTransformOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - box_restore, ops::BoxRestoreCPUKernel, - ops::BoxRestoreCPUKernel); + polygon_box_transform, + ops::PolygonBoxTransformCPUKernel, + ops::PolygonBoxTransformCPUKernel); diff --git a/paddle/fluid/operators/detection/box_restore_op.cu b/paddle/fluid/operators/detection/polygon_box_transform_op.cu similarity index 83% rename from paddle/fluid/operators/detection/box_restore_op.cu rename to paddle/fluid/operators/detection/polygon_box_transform_op.cu index c8ff262fb7f72b..6187ac6622c65d 100644 --- a/paddle/fluid/operators/detection/box_restore_op.cu +++ b/paddle/fluid/operators/detection/polygon_box_transform_op.cu @@ -24,8 +24,8 @@ using platform::PADDLE_CUDA_NUM_THREADS; #define CUDA_BLOCK_SIZE 16 template -__global__ void BoxRestoreKernel(const int n, const int h, const int w, - const T* input, T* output) { +__global__ void PolygonBoxTransformKernel(const int n, const int h, const int w, + const T* input, T* output) { int id_n = threadIdx.x + blockDim.x * blockIdx.x; int id_h = threadIdx.y + blockDim.y * blockIdx.y; int id_w = threadIdx.z + blockDim.z * blockIdx.z; @@ -40,7 +40,7 @@ __global__ void BoxRestoreKernel(const int n, const int h, const int w, } template -class BoxRestoreOpCUDAKernel : public framework::OpKernel { +class PolygonBoxTransformOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -62,7 +62,7 @@ class BoxRestoreOpCUDAKernel : public framework::OpKernel { (height + threadsPerBlock.y - 1) / threadsPerBlock.y, (width + threadsPerBlock.z - 1) / threadsPerBlock.z); auto stream = ctx.cuda_device_context().stream(); - BoxRestoreKernel<<>>( + PolygonBoxTransformKernel<<>>( batch_size * geo_channels, height, width, in_data, out_data); } }; @@ -70,6 +70,7 @@ class BoxRestoreOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(box_restore, - paddle::operators::BoxRestoreOpCUDAKernel, - paddle::operators::BoxRestoreOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + polygon_box_transform, + paddle::operators::PolygonBoxTransformOpCUDAKernel, + paddle::operators::PolygonBoxTransformOpCUDAKernel); diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index d41e253ed82bfa..269f1ddc2968be 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -342,9 +342,8 @@ def find_actual(target_name, fetch_list): def check_output(self, atol=1e-5): places = [core.CPUPlace()] - # if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): - - # places.append(core.CUDAPlace(0)) + if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): + places.append(core.CUDAPlace(0)) for place in places: self.check_output_with_place(place, atol) diff --git a/python/paddle/fluid/tests/unittests/test_box_restore.py b/python/paddle/fluid/tests/unittests/test_polygon_box_transform.py similarity index 88% rename from python/paddle/fluid/tests/unittests/test_box_restore.py rename to 
python/paddle/fluid/tests/unittests/test_polygon_box_transform.py index 96ea9bd9b82c45..2105d320665367 100644 --- a/python/paddle/fluid/tests/unittests/test_box_restore.py +++ b/python/paddle/fluid/tests/unittests/test_polygon_box_transform.py @@ -17,7 +17,7 @@ from op_test import OpTest -def BoxRestore(input): +def PolygonBoxRestore(input): shape = input.shape batch_size = shape[0] geo_channels = shape[1] @@ -38,28 +38,28 @@ def BoxRestore(input): input.shape) - input # [batch_size, geo_channels, h, w] -class TestBoxRestoreOp(OpTest): +class TestPolygonBoxRestoreOp(OpTest): def config(self): self.input_shape = (1, 8, 2, 2) def setUp(self): self.config() - self.op_type = "box_restore" + self.op_type = "polygon_box_transform" input = np.random.random(self.input_shape).astype("float32") self.inputs = {'Input': input} - output = BoxRestore(input) + output = PolygonBoxRestore(input) self.outputs = {'Output': output} def test_check_output(self): self.check_output() -class TestCase1(TestBoxRestoreOp): +class TestCase1(TestPolygonBoxRestoreOp): def config(self): self.input_shape = (2, 10, 3, 2) -class TestCase2(TestBoxRestoreOp): +class TestCase2(TestPolygonBoxRestoreOp): def config(self): self.input_shape = (3, 12, 4, 5)
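
For reference, the semantics that the renamed polygon_box_transform op (and the unit test above) implement can be written in a few lines of NumPy. The sketch below is only an illustration of the computation carried out by the CPU and CUDA kernels — even input channels hold x-offsets that are restored with the width index, odd channels hold y-offsets restored with the height index — and the helper name polygon_box_transform_ref is chosen here for illustration only; it does not appear in the patch.

    import numpy as np

    def polygon_box_transform_ref(geo):
        # geo: [batch, 2*n, h, w] geometry map; even channels are x-offsets,
        # odd channels are y-offsets from the polygon vertices to each pixel.
        batch, channels, h, w = geo.shape
        assert channels % 2 == 0, "geometry map needs (x, y) channel pairs"
        xs, ys = np.meshgrid(np.arange(w), np.arange(h))  # each of shape [h, w]
        out = np.empty_like(geo)
        out[:, 0::2, :, :] = xs - geo[:, 0::2, :, :]  # x channels: id_w - input
        out[:, 1::2, :, :] = ys - geo[:, 1::2, :, :]  # y channels: id_h - input
        return out

Running a random float32 array of shape (2, 10, 3, 2) through this helper should reproduce the expected output that TestCase1 builds via PolygonBoxRestore.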