@@ -25,13 +25,15 @@ void ProdRawKernel(const Context& dev_ctx,
25
25
bool reduce_all,
26
26
phi::DenseTensor* out) {
27
27
auto dims = axes.GetData ();
28
+ auto x_dims = x.dims ();
29
+ auto x_dims_size = x_dims.size ();
28
30
dev_ctx.template Alloc <T>(out);
29
31
30
32
NPUAttributeMap attr_input = {{" axes" , dims}, {" keep_dims" , keep_dim}};
31
33
32
34
if (reduce_all) {
33
35
std::vector<int > dim_vec;
34
- for (int i = 0 ; i < x. dims (). size () ; i++) {
36
+ for (int i = 0 ; i < x_dims_size ; i++) {
35
37
dim_vec.push_back (i);
36
38
}
37
39
@@ -56,8 +58,37 @@ void ProdRawKernel(const Context& dev_ctx,
56
58
{phi::DenseTensorMeta::DataType::INT32},
57
59
{phi::DenseTensorMeta::DataType::INT32});
58
60
} else {
59
- const auto & runner = NpuOpRunner (" ReduceProdD" , {x}, {*out}, attr_input);
60
- runner.Run (dev_ctx.stream ());
61
+ // TODO(Aganlengzi): remove this branch when performance of ReduceProdD
62
+ // is good enough for big shapes.
63
+ // Here, we use SplitV and Mul to deal with special cases.
64
+ if (x_dims[x_dims_size - 1 ] == 2 && dims.size () == 1 &&
65
+ (dims[0 ] == -1 || dims[0 ] == x_dims_size - 1 )) {
66
+ auto stream = dev_ctx.stream ();
67
+ phi::DenseTensor x1, x2;
68
+ x1.set_meta (out->meta ());
69
+ x2.set_meta (out->meta ());
70
+ dev_ctx.template Alloc <T>(&x1);
71
+ dev_ctx.template Alloc <T>(&x2);
72
+ // split
73
+ std::vector<phi::DenseTensor> outputs;
74
+ outputs.push_back (x1);
75
+ outputs.push_back (x2);
76
+ std::vector<int > sections = {1 , 1 };
77
+ NpuOpRunner runner_split;
78
+ runner_split.SetType (" SplitV" )
79
+ .AddInput (x)
80
+ .AddInput (dev_ctx, std::move (sections))
81
+ .AddInput (dev_ctx, std::vector<int32_t >({-1 }))
82
+ .AddOutputs (outputs)
83
+ .AddAttrs ({{" num_split" , static_cast <int32_t >(sections.size ())}})
84
+ .Run (stream);
85
+ // elementwise mul
86
+ const auto & runner = NpuOpRunner (" Mul" , {x1, x2}, {*out}, {});
87
+ runner.Run (stream);
88
+ } else {
89
+ const auto & runner = NpuOpRunner (" ReduceProdD" , {x}, {*out}, attr_input);
90
+ runner.Run (dev_ctx.stream ());
91
+ }
61
92
}
62
93
}
63
94
0 commit comments