Skip to content

Commit 2f4a262

Browse files
authored
[npu] add cumprod (PaddlePaddle#1377)
1 parent eb2fe9e commit 2f4a262

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "kernels/funcs/npu_funcs.h"
16+
#include "kernels/funcs/npu_op_runner.h"
17+
18+
namespace custom_kernel {
19+
20+
template <typename T, typename Context>
21+
void CumprodKernel(const Context& dev_ctx,
22+
const phi::DenseTensor& input,
23+
const int dim,
24+
bool exclusive,
25+
bool reverse,
26+
phi::DenseTensor* out) {
27+
auto stream = dev_ctx.stream();
28+
dev_ctx.template Alloc<T>(out);
29+
30+
NPUAttributeMap attr_input = {
31+
{"axis", dim}, {"exclusive", exclusive}, {"reverse", reverse}};
32+
33+
const auto& runner = NpuOpRunner("Cumprod", {input}, {*out}, attr_input);
34+
runner.Run(stream);
35+
}
36+
} // namespace custom_kernel
37+
38+
PD_REGISTER_PLUGIN_KERNEL(cumprod,
39+
npu,
40+
ALL_LAYOUT,
41+
custom_kernel::CumprodKernel,
42+
float,
43+
double,
44+
int,
45+
int64_t,
46+
phi::dtype::complex<float>,
47+
phi::dtype::complex<double>) {}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import numpy as np
19+
20+
import paddle
21+
import paddle.base.core as core
22+
import paddle.base as base
23+
24+
paddle.enable_static()
25+
26+
27+
class TestCumprodOp(unittest.TestCase):
28+
def run_cases(self):
29+
data_np = np.arange(12).reshape(3, 4)
30+
data = paddle.to_tensor(data_np)
31+
32+
y = paddle.cumprod(data)
33+
z = np.cumprod(data_np)
34+
self.assertTrue(np.array_equal(z, y.numpy()))
35+
36+
y = paddle.cumprod(data, dim=0)
37+
z = np.cumprod(data_np, axis=0)
38+
self.assertTrue(np.array_equal(z, y.numpy()))
39+
40+
y = paddle.cumprod(data, dim=-1)
41+
z = np.cumprod(data_np, axis=-1)
42+
self.assertTrue(np.array_equal(z, y.numpy()))
43+
44+
y = paddle.cumprod(data, dtype="float32")
45+
self.assertTrue(y.dtype == core.VarDesc.VarType.FP32)
46+
47+
y = paddle.cumprod(data, dtype="int32")
48+
self.assertTrue(y.dtype == core.VarDesc.VarType.INT32)
49+
50+
y = paddle.cumprod(data, dim=-2)
51+
z = np.cumprod(data_np, axis=-2)
52+
self.assertTrue(np.array_equal(z, y.numpy()))
53+
54+
def run_static(self, use_custom_device=False):
55+
with base.program_guard(base.Program()):
56+
data_np = np.random.random((100, 100)).astype(np.float32)
57+
x = paddle.static.data("X", [100, 100])
58+
y = paddle.cumprod(x, dim=0)
59+
y2 = paddle.cumprod(x, dim=1)
60+
y3 = paddle.cumprod(x, dim=-1)
61+
y4 = paddle.cumprod(x, dim=0, dtype="float32")
62+
y5 = paddle.cumprod(x, dim=0, dtype="int32")
63+
y6 = paddle.cumprod(x, dim=-2)
64+
65+
place = base.CustomPlace("npu", 0) if use_custom_device else base.CPUPlace()
66+
exe = base.Executor(place)
67+
exe.run(base.default_startup_program())
68+
out = exe.run(
69+
feed={"X": data_np},
70+
fetch_list=[y.name, y2.name, y3.name, y4.name, y5.name, y6.name],
71+
)
72+
73+
z = np.cumprod(data_np, axis=0)
74+
self.assertTrue(np.allclose(z, out[0]))
75+
z = np.cumprod(data_np, axis=1)
76+
self.assertTrue(np.allclose(z, out[1]))
77+
z = np.cumprod(data_np, axis=-1)
78+
self.assertTrue(np.allclose(z, out[2]))
79+
self.assertTrue(out[3].dtype == np.float32)
80+
self.assertTrue(out[4].dtype == np.int32)
81+
z = np.cumprod(data_np, axis=-2)
82+
self.assertTrue(np.allclose(z, out[5]))
83+
84+
def test_npu(self):
85+
# Now, npu tests need setting paddle.enable_static()
86+
87+
self.run_static(use_custom_device=True)
88+
89+
def test_name(self):
90+
with base.program_guard(base.Program()):
91+
x = paddle.static.data("x", [3, 4])
92+
y = paddle.cumprod(x, dim=0, name="out")
93+
self.assertTrue("out" in y.name)
94+
95+
96+
if __name__ == "__main__":
97+
unittest.main()

0 commit comments

Comments
 (0)