Commit eb2fe9e

[npu]add randint (PaddlePaddle#1379)
1 parent 464727d commit eb2fe9e

File tree

2 files changed: +170 lines, -0 lines

2 files changed

+170
-0
lines changed
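For context, a minimal dynamic-graph sketch of what this change enables from the Python side (a usage sketch, not part of the diff; the "npu" device string and index 0 follow the tests below):

import paddle

# Assumes the PaddleCustomDevice NPU plugin is installed and an NPU is visible.
paddle.disable_static(paddle.CustomPlace("npu", 0))

# Draw a 3x3 tensor of random integers in [0, 10); dtype defaults to int64.
x = paddle.randint(low=0, high=10, shape=[3, 3])
print(x.numpy())

paddle.enable_static()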
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "kernels/funcs/npu_funcs.h"
#include "kernels/funcs/npu_op_runner.h"

namespace custom_kernel {

template <typename T, typename Context>
void RandintKernel(const Context& dev_ctx,
                   const int low,
                   const int high,
                   const phi::IntArray& shape,
                   phi::DataType dtype UNUSED,
                   phi::DenseTensor* out) {
  // Allocate the output with the requested shape; the element type comes from
  // the template parameter T (int or int64_t per the registration below).
  out->Resize(common::make_ddim(shape.GetData()));
  dev_ctx.template Alloc<T>(out);
  // The aclnn in-place random op fills `out` with integers in [low, high).
  int64_t low_ = low;
  int64_t high_ = high;
  int64_t seed = 0;
  int64_t offset = 0;
  EXEC_NPU_CMD(aclnnInplaceRandom, dev_ctx, *out, low_, high_, seed, offset);
}
}  // namespace custom_kernel

PD_REGISTER_PLUGIN_KERNEL(
    randint, npu, ALL_LAYOUT, custom_kernel::RandintKernel, int, int64_t) {}
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

from tests.op_test import OpTest
import paddle
import paddle.base.core as core
from paddle.static import program_guard, Program

paddle.enable_static()


def check_randint_out(data, low, high):
    assert isinstance(data, np.ndarray), "The input data should be np.ndarray."
    mask = (data < low) | (data >= high)
    return not mask.any()


def convert_dtype(dtype_str):
    dtype_str_list = ["int32", "int64", "float32", "float64"]
    dtype_num_list = [
        core.VarDesc.VarType.INT32,
        core.VarDesc.VarType.INT64,
        core.VarDesc.VarType.FP32,
        core.VarDesc.VarType.FP64,
    ]
    assert dtype_str in dtype_str_list, dtype_str + " should in " + str(dtype_str_list)
    return dtype_num_list[dtype_str_list.index(dtype_str)]


class TestRandintOp(OpTest):
    """Test randint op."""

    def setUp(self):
        self.set_npu()
        self.op_type = "randint"
        self.low = 0
        self.high = 10
        self.shape = [3, 3]
        self.dtype = "int64"

        self.inputs = {}
        self.outputs = {"Out": np.zeros(self.shape).astype(self.dtype)}
        self.init_attrs()
        self.attrs = {
            "low": self.low,
            "high": self.high,
            "shape": self.shape,
            "dtype": convert_dtype(self.dtype),
        }

    def set_npu(self):
        self.__class__.use_custom_device = True

    def _get_places(self):
        return [paddle.CustomPlace("npu", 0)]

    def init_attrs(self):
        pass

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        out_np = np.array(outs[0])
        self.assertTrue(check_randint_out(out_np, self.low, self.high))


class TestRandintOpLow(TestRandintOp):
    def init_attrs(self):
        self.low = -10


class TestRandintOpHigh(TestRandintOp):
    def init_attrs(self):
        self.high = 5


class TestRandintOpInt32(TestRandintOp):
    def init_attrs(self):
        self.dtype = "int32"


class TestRandintAPI(unittest.TestCase):
    def test_out(self):
        low = -5
        high = 5
        shape = [2, 3]
        place = paddle.CustomPlace("npu", 0)
        with program_guard(Program(), Program()):
            x1 = paddle.randint(low, high, shape=shape)
            x2 = paddle.randint(low, high, shape=shape, dtype="int32")

            exe = paddle.static.Executor(place)
            res = exe.run(fetch_list=[x1, x2])

            self.assertEqual(res[0].dtype, np.int64)
            self.assertEqual(res[1].dtype, np.int32)
            self.assertTrue(check_randint_out(res[0], low, high))
            self.assertTrue(check_randint_out(res[1], low, high))


class TestRandintImperative(unittest.TestCase):
    def test_out(self):
        paddle.disable_static(paddle.CustomPlace("npu", 0))
        low = -10
        high = 10
        shape = [5, 3]
        for dtype in ["int32", np.int64]:
            data_p = paddle.randint(low, high, shape=shape, dtype=dtype)
            data_np = data_p.numpy()
            self.assertTrue(check_randint_out(data_np, low, high))
        paddle.enable_static()


if __name__ == "__main__":
    unittest.main()
