Commit 1dcf7c1

[NPU] reduce_prod modify (PaddlePaddle#234)
1 parent e77e5b0 commit 1dcf7c1

2 files changed: +47 -3 lines changed

backends/npu/kernels/reduce_prod_kernel.cc

Lines changed: 34 additions & 3 deletions
@@ -25,13 +25,15 @@ void ProdRawKernel(const Context& dev_ctx,
                    bool reduce_all,
                    phi::DenseTensor* out) {
   auto dims = axes.GetData();
+  auto x_dims = x.dims();
+  auto x_dims_size = x_dims.size();
   dev_ctx.template Alloc<T>(out);
 
   NPUAttributeMap attr_input = {{"axes", dims}, {"keep_dims", keep_dim}};
 
   if (reduce_all) {
     std::vector<int> dim_vec;
-    for (int i = 0; i < x.dims().size(); i++) {
+    for (int i = 0; i < x_dims_size; i++) {
       dim_vec.push_back(i);
     }
 
@@ -56,8 +58,37 @@ void ProdRawKernel(const Context& dev_ctx,
                              {phi::DenseTensorMeta::DataType::INT32},
                              {phi::DenseTensorMeta::DataType::INT32});
   } else {
-    const auto& runner = NpuOpRunner("ReduceProdD", {x}, {*out}, attr_input);
-    runner.Run(dev_ctx.stream());
+    // TODO(Aganlengzi): remove this branch when performance of ReduceProdD
+    // is good enough for big shapes.
+    // Here, we use SplitV and Mul to deal with special cases.
+    if (x_dims[x_dims_size - 1] == 2 && dims.size() == 1 &&
+        (dims[0] == -1 || dims[0] == x_dims_size - 1)) {
+      auto stream = dev_ctx.stream();
+      phi::DenseTensor x1, x2;
+      x1.set_meta(out->meta());
+      x2.set_meta(out->meta());
+      dev_ctx.template Alloc<T>(&x1);
+      dev_ctx.template Alloc<T>(&x2);
+      // split
+      std::vector<phi::DenseTensor> outputs;
+      outputs.push_back(x1);
+      outputs.push_back(x2);
+      std::vector<int> sections = {1, 1};
+      NpuOpRunner runner_split;
+      runner_split.SetType("SplitV")
+          .AddInput(x)
+          .AddInput(dev_ctx, std::move(sections))
+          .AddInput(dev_ctx, std::vector<int32_t>({-1}))
+          .AddOutputs(outputs)
+          .AddAttrs({{"num_split", static_cast<int32_t>(sections.size())}})
+          .Run(stream);
+      // elementwise mul
+      const auto& runner = NpuOpRunner("Mul", {x1, x2}, {*out}, {});
+      runner.Run(stream);
+    } else {
+      const auto& runner = NpuOpRunner("ReduceProdD", {x}, {*out}, attr_input);
+      runner.Run(dev_ctx.stream());
+    }
   }
 }

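The fast path added in the else branch relies on a simple identity: when the innermost dimension has size 2, the product along that axis is just the elementwise product of the two slices. A minimal NumPy sketch of that equivalence (illustration only, not part of the commit), using the same input shape as the new TestNPUReduceProd4 case:

    # Not part of the commit: a NumPy sketch of the identity the SplitV + Mul
    # fast path relies on, using the same input shape as TestNPUReduceProd4.
    import numpy as np

    x = np.random.random((32, 888, 50, 2)).astype("float32")

    # Reference result: reduce_prod over the last axis (keep_dims defaults to False).
    ref = x.prod(axis=-1)                  # shape (32, 888, 50)

    # Fast-path equivalent: split the trailing axis of size 2 into two slices
    # and multiply them elementwise, which is what SplitV followed by Mul does.
    x1, x2 = np.split(x, 2, axis=-1)       # each slice has shape (32, 888, 50, 1)
    fast = (x1 * x2).reshape(ref.shape)    # drop the trailing axis of size 1

    np.testing.assert_allclose(fast, ref)

Per the TODO above, this replaces a single ReduceProdD call, which is currently slow for big shapes, with one split and one elementwise multiply on the NPU.
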
backends/npu/tests/unittests/test_reduce_prod_op_npu.py

Lines changed: 13 additions & 0 deletions
@@ -71,6 +71,19 @@ def setUp(self):
         self.outputs = {'Out': self.inputs['X'].prod(axis=tuple([0]))}
 
 
+class TestNPUReduceProd4(TestNPUReduceProd):
+    def setUp(self):
+        self.op_type = "reduce_prod"
+        self.set_npu()
+        self.init_dtype()
+
+        self.inputs = {
+            'X': np.random.random((32, 888, 50, 2)).astype(self.dtype)
+        }
+        self.attrs = {'dim': [-1]}
+        self.outputs = {'Out': self.inputs['X'].prod(axis=tuple([-1]))}
+
+
 class TestNPUReduceProd6D(TestNPUReduceProd):
     def setUp(self):
         self.op_type = "reduce_prod"

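To make explicit which cases reach the new branch, below is a hypothetical Python mirror of the C++ guard in ProdRawKernel (illustration only, not code from the commit): only a single reduce dimension pointing at a trailing axis of size 2 takes the SplitV + Mul path, as in the new TestNPUReduceProd4 above; all other inputs still go through ReduceProdD.

    # Hypothetical Python mirror of the C++ condition guarding the SplitV + Mul
    # branch in ProdRawKernel; for illustration only, not code from the commit.
    def uses_split_mul_fast_path(x_shape, reduce_dims):
        last_axis = len(x_shape) - 1
        return (
            x_shape[-1] == 2
            and len(reduce_dims) == 1
            and reduce_dims[0] in (-1, last_axis)
        )

    print(uses_split_mul_fast_path((32, 888, 50, 2), [-1]))  # True  -> SplitV + Mul
    print(uses_split_mul_fast_path((32, 888, 50, 2), [0]))   # False -> ReduceProdD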