Skip to content

Commit 7bf84e2

Browse files
authored
add argmax and iou_similarity for kunlun (#35836)
* add argmax and iou_similarity for kunlun * add argmax and iou_similarity for kunlun * add argmax and iou_similarity for kunlun
1 parent 1548407 commit 7bf84e2

File tree

8 files changed

+449
-5
lines changed

8 files changed

+449
-5
lines changed

cmake/external/xpu.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ ELSE ()
3535
ENDIF()
3636

3737
SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
38-
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210909")
38+
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210917")
3939
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
4040
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
4141
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifdef PADDLE_WITH_XPU
16+
17+
#include "paddle/fluid/operators/arg_min_max_op_base.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
using Tensor = framework::Tensor;
22+
23+
template <typename DeviceContext, typename T>
24+
class ArgMaxXPUKernel : public framework::OpKernel<T> {
25+
public:
26+
void Compute(const framework::ExecutionContext& ctx) const override {
27+
auto* x = ctx.Input<framework::LoDTensor>("X");
28+
auto* out = ctx.Output<framework::LoDTensor>("Out");
29+
auto dtype = ctx.Attr<int>("dtype");
30+
PADDLE_ENFORCE_EQ(
31+
(dtype < 0 || dtype == 3), true,
32+
platform::errors::InvalidArgument(
33+
"The attribute of dtype in xpu argmin/argmax must be [%s], but "
34+
"received [%s]",
35+
paddle::framework::DataTypeToString(
36+
framework::proto::VarType::INT64),
37+
paddle::framework::DataTypeToString(
38+
static_cast<framework::proto::VarType::Type>(dtype))));
39+
40+
out->template mutable_data<int64_t>(ctx.GetPlace());
41+
auto axis = ctx.Attr<int64_t>("axis");
42+
const bool& flatten = ctx.Attr<bool>("flatten");
43+
framework::DDim x_dims;
44+
if (flatten) {
45+
x_dims = framework::make_ddim({x->numel()});
46+
// if flatten, the axis just as 0
47+
axis = 0;
48+
} else {
49+
x_dims = x->dims();
50+
if (axis < 0) axis += x_dims.size();
51+
}
52+
auto xdims_vec = framework::vectorize<int>(x_dims);
53+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
54+
int r = xpu::argmax(dev_ctx.x_context(), x->data<T>(), out->data<int64_t>(),
55+
xdims_vec, axis);
56+
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
57+
platform::errors::External(
58+
"XPU argmax kernel return wrong value[%d %s].", r,
59+
XPUAPIErrorMsg[r]));
60+
}
61+
};
62+
63+
} // namespace operators
64+
} // namespace paddle
65+
66+
namespace ops = paddle::operators;
67+
REGISTER_OP_XPU_KERNEL(
68+
arg_max, ops::ArgMaxXPUKernel<paddle::platform::XPUDeviceContext, float>);
69+
70+
#endif

paddle/fluid/operators/detection/CMakeLists.txt

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ endfunction()
1717

1818
detection_library(bipartite_match_op SRCS bipartite_match_op.cc)
1919
detection_library(box_coder_op SRCS box_coder_op.cc box_coder_op.cu)
20-
detection_library(iou_similarity_op SRCS iou_similarity_op.cc
21-
iou_similarity_op.cu)
2220
detection_library(mine_hard_examples_op SRCS mine_hard_examples_op.cc)
2321
detection_library(prior_box_op SRCS prior_box_op.cc prior_box_op.cu)
2422
detection_library(density_prior_box_op SRCS density_prior_box_op.cc density_prior_box_op.cu)
@@ -58,6 +56,12 @@ else()
5856
detection_library(collect_fpn_proposals_op SRCS collect_fpn_proposals_op.cc)
5957
endif()
6058

59+
if(WITH_XPU)
60+
detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op_xpu.cc)
61+
else()
62+
detection_library(iou_similarity_op SRCS iou_similarity_op.cc iou_similarity_op.cu)
63+
endif()
64+
6165
detection_library(roi_perspective_transform_op SRCS roi_perspective_transform_op.cc roi_perspective_transform_op.cu)
6266
#Export local libraries to parent
6367
# set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifdef PADDLE_WITH_XPU
16+
17+
#include "paddle/fluid/operators/detection/iou_similarity_op.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
template <typename DeviceContext, typename T>
23+
class XPUIOUSimilarityKernel : public framework::OpKernel<T> {
24+
public:
25+
void Compute(const framework::ExecutionContext& ctx) const override {
26+
const framework::LoDTensor* in_x = ctx.Input<framework::LoDTensor>("X");
27+
const framework::Tensor* in_y = ctx.Input<framework::Tensor>("Y");
28+
bool normalized = ctx.Attr<bool>("box_normalized");
29+
framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
30+
31+
int x_n = in_x->dims()[0];
32+
int y_n = in_y->dims()[0];
33+
T eps = static_cast<T>(1e-10);
34+
35+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
36+
int r = xpu::iou_similarity(
37+
dev_ctx.x_context(), in_x->data<T>(), in_y->data<T>(),
38+
out->mutable_data<T>(ctx.GetPlace()), x_n, y_n, eps, normalized);
39+
PADDLE_ENFORCE_EQ(
40+
r, XPU_SUCCESS,
41+
platform::errors::External(
42+
"XPU iou_similarity kernel return wrong value[%d %s].", r,
43+
XPUAPIErrorMsg[r]));
44+
}
45+
};
46+
47+
} // namespace operators
48+
} // namespace paddle
49+
50+
namespace ops = paddle::operators;
51+
using XPU = paddle::platform::XPUDeviceContext;
52+
53+
REGISTER_OP_XPU_KERNEL(iou_similarity, ops::XPUIOUSimilarityKernel<XPU, float>);
54+
55+
#endif

paddle/fluid/platform/xpu/xpu1_op_list.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,10 @@ XPUOpMap& get_kl1_ops() {
318318
pOpKernelType(vartype::INT8, XPUPlace()),
319319
pOpKernelType(vartype::UINT8, XPUPlace()),
320320
pOpKernelType(vartype::FP32, XPUPlace())})},
321-
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
321+
{"momuntem", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
322+
{"iou_similarity",
323+
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
324+
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
322325
// AddMore
323326
};
324327

paddle/fluid/platform/xpu/xpu2_op_list.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ XPUOpMap& get_kl2_ops() {
107107
{"transpose2_grad",
108108
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
109109
pOpKernelType(vartype::FP16, XPUPlace())})},
110-
110+
{"iou_similarity",
111+
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
112+
{"arg_max", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}
111113
// AddMore
112114
};
113115

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
import sys
20+
sys.path.append("..")
21+
from op_test import OpTest
22+
from op_test_xpu import XPUOpTest
23+
import paddle
24+
import paddle.fluid.core as core
25+
26+
paddle.enable_static()
27+
28+
29+
class XPUBaseTestCase(XPUOpTest):
    """Base op test for arg_max on XPU.

    Subclasses override ``initTestCase`` to choose ``dims``, ``dtype`` and
    ``axis``; ``setUp`` builds a seeded random input and the numpy reference
    output from those values.
    """

    def initTestCase(self):
        self.dims, self.dtype, self.axis = (3, 4), 'float32', 1

    def setUp(self):
        self.initTestCase()
        test_cls = self.__class__
        test_cls.op_type = 'arg_max'
        test_cls.use_xpu = True

        # Fixed seed keeps the generated input reproducible across runs.
        np.random.seed(2021)
        self.x = np.random.random(self.dims).astype(self.dtype)
        self.inputs = {'X': self.x}
        self.attrs = {'axis': self.axis, 'use_xpu': True}

        reference = np.argmin if self.op_type == "arg_min" else np.argmax
        self.outputs = {'Out': reference(self.x, axis=self.axis)}

    def test_check_output(self):
        # The check only runs when Paddle was built with XPU support.
        if not paddle.is_compiled_with_xpu():
            return
        self.check_output_with_place(paddle.XPUPlace(0))
52+
53+
54+
# test argmax, dtype: float32
55+
class TestArgMaxFloat32Case1(XPUBaseTestCase):
    """arg_max on a float32 (3, 4, 5) input with axis=-1."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, 4, 5), 'float32', -1
61+
62+
63+
class TestArgMaxFloat32Case2(XPUBaseTestCase):
    """arg_max on a float32 (3, 4, 5) input with axis=0."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, 4, 5), 'float32', 0
69+
70+
71+
class TestArgMaxFloat32Case3(XPUBaseTestCase):
    """arg_max on a float32 (3, 4, 5) input with axis=1."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, 4, 5), 'float32', 1
77+
78+
79+
class TestArgMaxFloat32Case4(XPUBaseTestCase):
    """arg_max on a float32 (3, 4, 5) input with axis=2."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, 4, 5), 'float32', 2
85+
86+
87+
class TestArgMaxFloat32Case5(XPUBaseTestCase):
    """arg_max on a float32 (3, 4) input with axis=-1."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, 4), 'float32', -1
93+
94+
95+
class TestArgMaxFloat32Case6(XPUBaseTestCase):
    """arg_max on a float32 (3, 4) input with axis=0."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, 4), 'float32', 0
101+
102+
103+
class TestArgMaxFloat32Case7(XPUBaseTestCase):
    """arg_max on a float32 (3, 4) input with axis=1."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, 4), 'float32', 1
109+
110+
111+
class TestArgMaxFloat32Case8(XPUBaseTestCase):
    """arg_max on a single-element float32 vector with axis=0."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (1, ), 'float32', 0
117+
118+
119+
class TestArgMaxFloat32Case9(XPUBaseTestCase):
    """arg_max on a two-element float32 vector with axis=0."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (2, ), 'float32', 0
125+
126+
127+
class TestArgMaxFloat32Case10(XPUBaseTestCase):
    """arg_max on a three-element float32 vector with axis=0."""

    def initTestCase(self):
        self.op_type = 'arg_max'
        self.dims, self.dtype, self.axis = (3, ), 'float32', 0
133+
134+
135+
class TestArgMaxAPI(unittest.TestCase):
    """Dygraph check of ``paddle.argmax`` on an XPU place against numpy."""

    def initTestCase(self):
        self.dims = (3, 4, 5)
        self.dtype = 'float32'
        self.axis = 0

    def setUp(self):
        # Skip instead of failing on builds without XPU support, matching
        # the availability guard used by XPUBaseTestCase.
        if not paddle.is_compiled_with_xpu():
            self.skipTest("Paddle is not compiled with XPU")
        self.initTestCase()
        self.__class__.use_xpu = True
        self.place = [paddle.XPUPlace(0)]

    def test_dygraph_api(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                numpy_input = (np.random.random(self.dims)).astype(self.dtype)
                tensor_input = paddle.to_tensor(numpy_input)
                numpy_output = np.argmax(numpy_input, axis=self.axis)
                paddle_output = paddle.argmax(tensor_input, axis=self.axis)
                self.assertTrue(
                    np.allclose(numpy_output, paddle_output.numpy()))
            finally:
                # Restore static mode even when an assertion above fails.
                paddle.enable_static()

        for place in self.place:
            run(place)
160+
161+
162+
class TestArgMaxAPI_2(unittest.TestCase):
    """Dygraph check of ``paddle.argmax`` with ``keepdim=True`` on XPU."""

    def initTestCase(self):
        self.dims = (3, 4, 5)
        self.dtype = 'float32'
        self.axis = 0
        self.keep_dims = True

    def setUp(self):
        # Skip instead of failing on builds without XPU support, matching
        # the availability guard used by XPUBaseTestCase.
        if not paddle.is_compiled_with_xpu():
            self.skipTest("Paddle is not compiled with XPU")
        self.initTestCase()
        self.__class__.use_xpu = True
        self.place = [paddle.XPUPlace(0)]

    def test_dygraph_api(self):
        def run(place):
            paddle.disable_static(place)
            try:
                np.random.seed(2021)
                numpy_input = (np.random.random(self.dims)).astype(self.dtype)
                tensor_input = paddle.to_tensor(numpy_input)
                # Re-insert the reduced axis with length 1 so the reference
                # shape follows self.dims/self.axis rather than a literal.
                numpy_output = np.expand_dims(
                    np.argmax(
                        numpy_input, axis=self.axis), axis=self.axis)
                paddle_output = paddle.argmax(
                    tensor_input, axis=self.axis, keepdim=self.keep_dims)
                self.assertTrue(
                    np.allclose(numpy_output, paddle_output.numpy()))
                self.assertEqual(numpy_output.shape,
                                 paddle_output.numpy().shape)
            finally:
                # Restore static mode even when an assertion above fails.
                paddle.enable_static()

        for place in self.place:
            run(place)
191+
192+
193+
if __name__ == '__main__':
194+
unittest.main()

0 commit comments

Comments
 (0)