
Commit cef4ed5

[NPU] Fix GatherND and Gather issues. (PaddlePaddle#1288)
1 parent b72edb8 commit cef4ed5

4 files changed: 88 additions, 16 deletions

backends/npu/kernels/gather_kernel.cc

Lines changed: 2 additions & 5 deletions
@@ -118,11 +118,8 @@ void GatherGradKernel(const Context& dev_ctx,
   zeroslike_xout.Resize(x.dims());
 
   // step3: scatter(x_grad)
-  const auto& runner_scatter = NpuOpRunner("TensorScatterUpdate",
-                                           {zeroslike_xout, *p_index, out_grad},
-                                           {*x_grad},
-                                           {});
-  runner_scatter.Run(stream);
+  EXEC_NPU_CMD(
+      aclnnScatterNd, dev_ctx, zeroslike_xout, *p_index, out_grad, *x_grad);
 }
 
 }  // namespace custom_kernel
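For reference, the scatter step that changed here builds x_grad by writing out_grad into a zero tensor at the gathered positions; the edit only swaps the graph-style NpuOpRunner("TensorScatterUpdate", ...) call for the single-operator aclnnScatterNd entry point via EXEC_NPU_CMD. A minimal host-side sketch of the semantics that step relies on, for a 1-D gather along axis 0 (illustration only, not the NPU code path; the function name is hypothetical):

    #include <cstdint>
    #include <vector>

    // Reference semantics of the "step3: scatter(x_grad)" stage above, for a
    // 1-D gather along axis 0: x_grad starts as zeros_like(x) and each gathered
    // position receives the corresponding output gradient.
    std::vector<float> GatherGradReference(const std::vector<float>& x,
                                           const std::vector<int64_t>& index,
                                           const std::vector<float>& out_grad) {
      std::vector<float> x_grad(x.size(), 0.0f);  // zeros_like(x)
      for (size_t i = 0; i < index.size(); ++i) {
        // Scatter-update: write out_grad[i] at the position that was gathered.
        // How duplicate indices combine follows the underlying NPU op and is
        // out of scope for this sketch.
        x_grad[static_cast<size_t>(index[i])] = out_grad[i];
      }
      return x_grad;
    }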

backends/npu/kernels/gather_nd_kernel.cc

Lines changed: 57 additions & 2 deletions
@@ -17,11 +17,64 @@
 
 namespace custom_kernel {
 
+template <typename T, typename Context>
+void AclopGatherNdKernel(const Context &dev_ctx,
+                         const phi::DenseTensor &x,
+                         const phi::DenseTensor &index,
+                         phi::DenseTensor *out) {
+  dev_ctx.template Alloc<T>(out);
+  auto stream = dev_ctx.stream();
+
+  if (x.numel() == 0) return;
+
+  if (index.numel() == 0) {
+    int diff = out->dims().size() - x.dims().size();
+    if (diff == 0) {
+      TensorCopy(dev_ctx, x, false, out);
+    } else {
+      std::vector<int64_t> new_dims(diff, 1);
+      for (size_t i = 0; i < x.dims().size(); ++i) {
+        new_dims.emplace_back(x.dims()[i]);
+      }
+
+      phi::DenseTensor x_tmp(x);
+      x_tmp.Resize(phi::make_ddim(new_dims));
+
+      NpuOpRunner runner;
+      runner.SetType("BroadcastTo")
+          .AddInput(x_tmp)
+          .AddInput(dev_ctx, phi::vectorize(out->dims()))
+          .AddOutput(*out);
+      runner.Run(stream);
+    }
+    return;
+  }
+
+  const auto &index_type = index.dtype();
+  bool index_type_match =
+      index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64;
+  PADDLE_ENFORCE_EQ(
+      index_type_match,
+      true,
+      phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s],"
+                                   "but desires to be [%s] or [%s]",
+                                   index_type,
+                                   phi::DataType::INT32,
+                                   phi::DataType::INT64));
+
+  const auto &runner = NpuOpRunner("GatherNd", {x, index}, {*out}, {});
+  runner.Run(stream);
+}
+
 template <typename T, typename Context>
 void GatherNdKernel(const Context &dev_ctx,
                     const phi::DenseTensor &x,
                     const phi::DenseTensor &index,
                     phi::DenseTensor *out) {
+  DO_COMPATIBILITY(
+      aclnnGatherNd,
+      (custom_kernel::AclopGatherNdKernel<T, Context>(dev_ctx, x, index, out)));
+
   dev_ctx.template Alloc<T>(out);
   auto stream = dev_ctx.stream();
 
@@ -62,8 +115,8 @@ void GatherNdKernel(const Context &dev_ctx,
                                    phi::DataType::INT32,
                                    phi::DataType::INT64));
 
-  const auto &runner = NpuOpRunner("GatherNd", {x, index}, {*out}, {});
-  runner.Run(stream);
+  bool negativeIndexSupport = false;
+  EXEC_NPU_CMD(aclnnGatherNd, dev_ctx, x, index, negativeIndexSupport, *out);
 }
 
 template <typename T, typename Context>
@@ -134,6 +187,7 @@ PD_REGISTER_PLUGIN_KERNEL(gather_nd,
                           ALL_LAYOUT,
                           custom_kernel::GatherNdKernel,
                           int64_t,
+                          int32_t,
                           float,
                           phi::dtype::float16,
                           phi::dtype::bfloat16) {}
@@ -143,6 +197,7 @@ PD_REGISTER_PLUGIN_KERNEL(gather_nd_grad,
                           ALL_LAYOUT,
                           custom_kernel::GatherNdGradKernel,
                           int64_t,
+                          int32_t,
                           float,
                           phi::dtype::float16,
                           phi::dtype::bfloat16) {}
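The structural change in this file is the aclop/aclnn split: the previous implementation is preserved verbatim as AclopGatherNdKernel, and GatherNdKernel now opens with DO_COMPATIBILITY so that runtimes without the aclnnGatherNd entry point fall back to the legacy path, while newer runtimes go through EXEC_NPU_CMD. The added int32_t registrations extend the same template kernels to 32-bit data without further code changes. A minimal sketch of the fallback pattern for a generic kernel, following the conventions of the kernels above and assuming DO_COMPATIBILITY runs the bracketed fallback and returns when the named aclnn API is unavailable (the op name "Foo"/aclnnFoo and both kernel names are placeholders, not real APIs):

    // Legacy implementation kept as a fallback (graph/aclop execution path).
    template <typename T, typename Context>
    void AclopFooKernel(const Context& dev_ctx,
                        const phi::DenseTensor& x,
                        phi::DenseTensor* out) {
      dev_ctx.template Alloc<T>(out);
      const auto& runner = NpuOpRunner("Foo", {x}, {*out}, {});
      runner.Run(dev_ctx.stream());
    }

    // New entry point: prefer the single-operator (aclnn) API when present.
    template <typename T, typename Context>
    void FooKernel(const Context& dev_ctx,
                   const phi::DenseTensor& x,
                   phi::DenseTensor* out) {
      // Falls back to the aclop kernel (and returns) if aclnnFoo is unavailable.
      DO_COMPATIBILITY(
          aclnnFoo,
          (custom_kernel::AclopFooKernel<T, Context>(dev_ctx, x, out)));
      dev_ctx.template Alloc<T>(out);
      EXEC_NPU_CMD(aclnnFoo, dev_ctx, x, *out);
    }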

backends/npu/kernels/index_sample_kernel.cc

Lines changed: 20 additions & 6 deletions
@@ -23,6 +23,12 @@ void ConcatKernel(const Context& dev_ctx,
                   const phi::Scalar& axis_scalar,
                   phi::DenseTensor* out);
 
+template <typename T, typename Context>
+void GatherNdKernel(const Context& dev_ctx,
+                    const phi::DenseTensor& x,
+                    const phi::DenseTensor& index,
+                    phi::DenseTensor* out);
+
 template <typename T, typename Context>
 void IndexSampleGather(const Context& dev_ctx,
                        const phi::DenseTensor* index,
@@ -150,12 +156,20 @@ void IndexSampleGather(const Context& dev_ctx,
   TensorFromVector(dev_ctx, gather_index_vec, dev_ctx, &gather_index);
   gather_index.Resize({batch_size, index_length, 2});
 
-  NpuOpRunner runner;
-  runner.SetType("GatherNd")
-      .AddInput(*input)
-      .AddInput(gather_index)
-      .AddOutput(*out);
-  runner.Run(dev_ctx.stream());
+  auto dtype = input->dtype();
+  if (dtype == phi::DataType::FLOAT32) {
+    custom_kernel::GatherNdKernel<float, Context>(
+        dev_ctx, *input, gather_index, out);
+  } else if (dtype == phi::DataType::INT32) {
+    custom_kernel::GatherNdKernel<int32_t, Context>(
+        dev_ctx, *input, gather_index, out);
+  } else if (dtype == phi::DataType::INT64) {
+    custom_kernel::GatherNdKernel<int64_t, Context>(
+        dev_ctx, *input, gather_index, out);
+  } else if (dtype == phi::DataType::FLOAT16) {
+    custom_kernel::GatherNdKernel<phi::dtype::float16, Context>(
+        dev_ctx, *input, gather_index, out);
+  }
 }
 }
 
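IndexSampleGather builds gather_index as a [batch_size, index_length, 2] tensor of (row, col) coordinates and now routes the lookup through the shared GatherNdKernel, so index_sample also picks up the aclnn path and its compatibility fallback; the explicit dtype chain is needed because the callee is a template instantiated per element type. A small host-side illustration of how a trailing-2 index addresses a 2-D input, assuming index_sample's usual semantics out[i][j] = x[i][index[i][j]] (reference code only; the function name is hypothetical):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Host-side reference for a GatherNd whose index has a trailing dimension
    // of 2: x is a [rows, cols] matrix stored row-major, and each (row, col)
    // pair in gather_index selects one element of the output.
    std::vector<float> GatherPairsReference(
        const std::vector<float>& x,
        int64_t cols,
        const std::vector<std::pair<int64_t, int64_t>>& gather_index) {
      std::vector<float> out;
      out.reserve(gather_index.size());
      for (const auto& rc : gather_index) {
        out.push_back(x[rc.first * cols + rc.second]);  // out[k] = x[row][col]
      }
      return out;
    }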

backends/npu/kernels/scatter_nd_add_kernel.cc

Lines changed: 9 additions & 3 deletions
@@ -17,6 +17,12 @@
 
 namespace custom_kernel {
 
+template <typename T, typename Context>
+void GatherNdKernel(const Context& dev_ctx,
+                    const phi::DenseTensor& x,
+                    const phi::DenseTensor& index,
+                    phi::DenseTensor* out);
+
 template <typename T, typename Context>
 void ScatterNdAddKernel(const Context& dev_ctx,
                         const phi::DenseTensor& x,
@@ -90,9 +96,9 @@ void ScatterNdAddGradKernel(const Context& dev_ctx,
   }
   if (updates_grad) {
     dev_ctx.template Alloc<T>(updates_grad);
-    const auto& runner =
-        NpuOpRunner("GatherNd", {out_grad, index}, {*updates_grad}, {});
-    runner.Run(dev_ctx.stream());
+
+    custom_kernel::GatherNdKernel<T, Context>(
+        dev_ctx, out_grad, index, updates_grad);
   }
 }
 
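In the backward pass of scatter_nd_add, each updates element was added into out at the position given by its index row, so by the chain rule the gradient with respect to updates is out_grad gathered at those same positions; the change simply reuses the shared GatherNdKernel for that gather instead of issuing a raw NpuOpRunner("GatherNd", ...) call, so it too follows the aclnn path when available. A compact host-side reference of that relationship under the simplifying assumption of a 1-D out tensor and an index of shape [N, 1] (the function name is hypothetical):

    #include <cstdint>
    #include <vector>

    // Since the forward pass does out[index[n]] += updates[n], the chain rule
    // gives updates_grad[n] = out_grad[index[n]], i.e. a gather_nd over out_grad.
    std::vector<float> ScatterNdAddUpdatesGrad(const std::vector<float>& out_grad,
                                               const std::vector<int64_t>& index) {
      std::vector<float> updates_grad(index.size());
      for (size_t n = 0; n < index.size(); ++n) {
        updates_grad[n] = out_grad[static_cast<size_t>(index[n])];
      }
      return updates_grad;
    }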
