Skip to content

Commit 376c948

Browse files
wanghaoshuangqingqing01
authored and committed
Polygon box transform op for OCR East detection. (#10802)
* Add quad transform. * Fix some syntax errors. * Fix CUDA kernel launch configuration. * Generalize geometry channels. * Rename QuadTransform to PolygonRestore. * Rename op. * Rename op and fix computation. * Modify CMakeLists.txt for box_restore op. * Refine code: 1. rename op 2. uncomment unittest on GPU
1 parent a62bbd1 commit 376c948

File tree

5 files changed

+255
-4
lines changed

5 files changed

+255
-4
lines changed

paddle/fluid/operators/detection/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc)
2424
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
2525
detection_library(target_assign_op SRCS target_assign_op.cc
2626
target_assign_op.cu)
27+
detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
28+
polygon_box_transform_op.cu)
2729

2830
# Export local libraries to parent
2931
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_registry.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using Tensor = framework::Tensor;
21+
22+
template <typename DeviceContext, typename T>
23+
class PolygonBoxTransformCPUKernel : public framework::OpKernel<T> {
24+
public:
25+
void Compute(const framework::ExecutionContext& ctx) const override {
26+
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
27+
"It must use CUDAPlace.");
28+
auto* in = ctx.Input<Tensor>("Input");
29+
auto in_dims = in->dims();
30+
const T* in_data = in->data<T>();
31+
auto* out = ctx.Output<Tensor>("Output");
32+
T* out_data = out->mutable_data<T>(ctx.GetPlace());
33+
34+
int batch_size = in_dims[0];
35+
int geo_channel = in_dims[1];
36+
int height = in_dims[2];
37+
int width = in_dims[3];
38+
int id = 0;
39+
for (int id_n = 0; id_n < batch_size * geo_channel; ++id_n) {
40+
for (int id_h = 0; id_h < height; ++id_h) {
41+
for (int id_w = 0; id_w < width; ++id_w) {
42+
id = id_n * height * width + width * id_h + id_w;
43+
if (id_n % 2 == 0) {
44+
out_data[id] = id_w - in_data[id];
45+
} else {
46+
out_data[id] = id_h - in_data[id];
47+
}
48+
}
49+
}
50+
}
51+
}
52+
};
53+
54+
class PolygonBoxTransformOp : public framework::OperatorWithKernel {
55+
public:
56+
using framework::OperatorWithKernel::OperatorWithKernel;
57+
58+
void InferShape(framework::InferShapeContext* ctx) const override {
59+
PADDLE_ENFORCE(
60+
ctx->HasInput("Input"),
61+
"Input (Input) of polygon_box transform op should not be null.");
62+
PADDLE_ENFORCE(
63+
ctx->HasOutput("Output"),
64+
"Output (Output) of polygon_box transform op should not be null.");
65+
66+
auto in_dim = ctx->GetInputDim("Input");
67+
68+
PADDLE_ENFORCE_EQ(in_dim.size(), 4, "input's rank must be 4.");
69+
PADDLE_ENFORCE_EQ(in_dim[1] % 2, 0,
70+
"input's second dimension must be even.");
71+
72+
ctx->SetOutputDim("Output", in_dim);
73+
}
74+
};
75+
76+
// Proto maker for polygon_box_transform: declares one input ("Input"), one
// output ("Output") and the operator's documentation string.
class PolygonBoxTransformOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Single 4-D input tensor.
    AddInput(
        "Input",
        "The input with shape [batch_size, geometry_channels, height, width]");

    // Output mirrors the input shape.
    AddOutput("Output", "The output with the same shape as input");

    // Operator-level documentation exposed through the op proto.
    AddComment(R"DOC(
PolygonBoxTransform Operator.
The input is the final geometry output in detection network.
We use 2*n numbers to denote the coordinate shift from n corner vertices of
the polygon_box to the pixel location. As each distance offset contains two numbers (xi, yi),
the geometry output contains 2*n channels.
PolygonBoxTransform Operator is used to transform the coordinate shift to the real coordinate.
)DOC");
  }
};
94+
95+
} // namespace operators
96+
} // namespace paddle
97+
98+
namespace ops = paddle::operators;
99+
REGISTER_OPERATOR(polygon_box_transform, ops::PolygonBoxTransformOp,
100+
ops::PolygonBoxTransformOpMaker,
101+
paddle::framework::EmptyGradOpMaker);
102+
REGISTER_OP_CPU_KERNEL(
103+
polygon_box_transform,
104+
ops::PolygonBoxTransformCPUKernel<paddle::platform::CPUPlace, float>,
105+
ops::PolygonBoxTransformCPUKernel<paddle::platform::CPUPlace, double>);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/platform/cuda_primitives.h"
17+
#include "paddle/fluid/platform/gpu_info.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
using platform::PADDLE_CUDA_NUM_THREADS;
24+
#define CUDA_BLOCK_SIZE 16
25+
26+
template <typename T>
27+
__global__ void PolygonBoxTransformKernel(const int n, const int h, const int w,
28+
const T* input, T* output) {
29+
int id_n = threadIdx.x + blockDim.x * blockIdx.x;
30+
int id_h = threadIdx.y + blockDim.y * blockIdx.y;
31+
int id_w = threadIdx.z + blockDim.z * blockIdx.z;
32+
if (id_n < n && id_h < h && id_w < w) {
33+
int id = id_n * h * w + w * id_h + id_w;
34+
if (id_n % 2 == 0) {
35+
output[id] = id_w - input[id];
36+
} else {
37+
output[id] = id_h - input[id];
38+
}
39+
}
40+
}
41+
42+
template <typename T>
43+
class PolygonBoxTransformOpCUDAKernel : public framework::OpKernel<T> {
44+
public:
45+
void Compute(const framework::ExecutionContext& ctx) const override {
46+
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
47+
"It must use CUDAPlace.");
48+
auto* in = ctx.Input<Tensor>("Input");
49+
auto in_dims = in->dims();
50+
const T* in_data = in->data<T>();
51+
auto* out = ctx.Output<Tensor>("Output");
52+
T* out_data = out->mutable_data<T>(ctx.GetPlace());
53+
54+
int batch_size = in_dims[0];
55+
int geo_channels = in_dims[1];
56+
int height = in_dims[2];
57+
int width = in_dims[3];
58+
dim3 threadsPerBlock(
59+
PADDLE_CUDA_NUM_THREADS / (CUDA_BLOCK_SIZE * CUDA_BLOCK_SIZE),
60+
CUDA_BLOCK_SIZE, CUDA_BLOCK_SIZE);
61+
dim3 numBlocks((batch_size * geo_channels) / threadsPerBlock.x,
62+
(height + threadsPerBlock.y - 1) / threadsPerBlock.y,
63+
(width + threadsPerBlock.z - 1) / threadsPerBlock.z);
64+
auto stream = ctx.cuda_device_context().stream();
65+
PolygonBoxTransformKernel<T><<<numBlocks, threadsPerBlock, 0, stream>>>(
66+
batch_size * geo_channels, height, width, in_data, out_data);
67+
}
68+
};
69+
70+
} // namespace operators
71+
} // namespace paddle
72+
73+
REGISTER_OP_CUDA_KERNEL(
74+
polygon_box_transform,
75+
paddle::operators::PolygonBoxTransformOpCUDAKernel<float>,
76+
paddle::operators::PolygonBoxTransformOpCUDAKernel<double>);

python/paddle/fluid/tests/unittests/op_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -479,19 +479,19 @@ def _numpy_to_lod_tensor(np_value, lod, place):
479479
def np_dtype_to_fluid_dtype(input):
480480
"""Change the dtype of float16 numpy array
481481
482-
numpy float16 is binded to paddle::platform::float16
482+
numpy float16 is binded to paddle::platform::float16
483483
in tensor_py.h via the help of uint16 data type since
484-
the internal memory representation of float16 is
484+
the internal memory representation of float16 is
485485
uint16_t in paddle and np.uint16 in numpy, which are
486486
themselves binded together by pybind.
487487
488488
Args:
489489
input: input numpy array
490490
491491
Returns:
492-
input: The dtype of input will be changed to np.uint16 if
492+
input: The dtype of input will be changed to np.uint16 if
493493
it is originally np.float16, such that the internal memory
494-
of input will be reinterpreted as of dtype np.uint16.
494+
of input will be reinterpreted as of dtype np.uint16.
495495
"""
496496
if input.dtype == np.float16:
497497
input.dtype = np.uint16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
from op_test import OpTest
18+
19+
20+
def PolygonBoxRestore(input):
    """NumPy reference implementation of the polygon_box_transform op.

    Args:
        input: 4-D array of shape [batch_size, geo_channels, h, w]. The
            construction below pairs channels as (w-index, h-index), so
            geo_channels is expected to be even.

    Returns:
        Array with the same shape as ``input``: even channels hold
        ``column_index - input`` and odd channels hold ``row_index - input``.
    """
    shape = input.shape
    batch_size = shape[0]
    geo_channels = shape[1]
    h = shape[2]
    w = shape[3]
    # list(range(...)) keeps the list-repetition semantics on Python 3, where
    # a bare range object cannot be multiplied.
    h_indexes = np.array(list(range(h)) * w).reshape(
        [w, h]).transpose()[np.newaxis, :]  # [1, h, w]
    w_indexes = np.array(list(range(w)) * h).reshape(
        [h, w])[np.newaxis, :]  # [1, h, w]
    indexes = np.concatenate(
        (w_indexes, h_indexes))[np.newaxis, :]  # [1, 2, h, w]
    # Floor division: the repeat count must be an integer on Python 3.
    indexes = indexes.repeat(
        [geo_channels // 2],
        axis=0)[np.newaxis, :]  # [1, geo_channels/2, 2, h, w]
    indexes = indexes.repeat(
        [batch_size], axis=0)  # [batch_size, geo_channels/2, 2, h, w]
    return indexes.reshape(
        input.shape) - input  # [batch_size, geo_channels, h, w]
39+
40+
41+
class TestPolygonBoxRestoreOp(OpTest):
    """Checks polygon_box_transform against the PolygonBoxRestore reference."""

    def config(self):
        # Input shape used by setUp; subclasses override this.
        self.input_shape = (1, 8, 2, 2)

    def setUp(self):
        self.config()
        self.op_type = "polygon_box_transform"
        data = np.random.random(self.input_shape).astype("float32")
        expected = PolygonBoxRestore(data)
        self.inputs = {'Input': data}
        self.outputs = {'Output': expected}

    def test_check_output(self):
        self.check_output()
55+
56+
57+
class TestCase1(TestPolygonBoxRestoreOp):
    """Same check with input shape (2, 10, 3, 2)."""

    def config(self):
        self.input_shape = (2, 10, 3, 2)
60+
61+
62+
class TestCase2(TestPolygonBoxRestoreOp):
    """Same check with input shape (3, 12, 4, 5)."""

    def config(self):
        self.input_shape = (3, 12, 4, 5)
65+
66+
67+
if __name__ == '__main__':
68+
unittest.main()

0 commit comments

Comments
 (0)