Skip to content

Commit def8dc1

Browse files
[NPU] concat_grad to aclnn. (PaddlePaddle#1286)
1 parent ce52fc8 commit def8dc1

File tree

1 file changed

+53
-6
lines changed

1 file changed

+53
-6
lines changed

backends/npu/kernels/concat_kernel.cc

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ void ConcatKernel(const Context& dev_ctx,
145145
}
146146

147147
template <typename T, typename Context>
148-
void ConcatGradKernel(const Context& dev_ctx,
149-
const std::vector<const phi::DenseTensor*>& ins,
150-
const phi::DenseTensor& dout,
151-
const phi::Scalar& axis_scalar,
152-
std::vector<phi::DenseTensor*> outs) {
148+
void AclopConcatGradKernel(const Context& dev_ctx,
149+
const std::vector<const phi::DenseTensor*>& ins,
150+
const phi::DenseTensor& dout,
151+
const phi::Scalar& axis_scalar,
152+
std::vector<phi::DenseTensor*> outs) {
153153
auto stream = dev_ctx.stream();
154154

155155
int axis = axis_scalar.to<int>();
@@ -186,6 +186,54 @@ void ConcatGradKernel(const Context& dev_ctx,
186186
}
187187
}
188188

189+
template <typename T, typename Context>
190+
void ConcatGradKernel(const Context& dev_ctx,
191+
const std::vector<const phi::DenseTensor*>& ins,
192+
const phi::DenseTensor& dout,
193+
const phi::Scalar& axis_scalar,
194+
std::vector<phi::DenseTensor*> outs) {
195+
DO_COMPATIBILITY(aclnnSliceV2,
196+
(custom_kernel::AclopConcatGradKernel<T, Context>(
197+
dev_ctx, ins, dout, axis_scalar, outs)));
198+
auto stream = dev_ctx.stream();
199+
200+
int axis = axis_scalar.to<int>();
201+
axis = ComputeAxis(static_cast<int64_t>(axis),
202+
static_cast<int64_t>(ins[0]->dims().size()));
203+
204+
std::vector<int64_t> axes_t;
205+
axes_t.push_back(axis);
206+
207+
int offset = 0;
208+
for (size_t j = 0; j < outs.size(); ++j) {
209+
if (outs[j] && outs[j]->numel() != 0UL) {
210+
dev_ctx.template Alloc<T>(outs[j]);
211+
212+
std::vector<int64_t> starts_array;
213+
starts_array.push_back(offset);
214+
std::vector<int64_t> ends_array;
215+
ends_array.push_back(ins[j]->dims()[axis] + offset);
216+
217+
std::vector<int64_t> steps;
218+
for (int i = 0; i < outs[j]->dims().size(); i++) {
219+
steps.push_back(1.0);
220+
}
221+
222+
EXEC_NPU_CMD(aclnnSliceV2,
223+
dev_ctx,
224+
dout,
225+
starts_array,
226+
ends_array,
227+
axes_t,
228+
steps,
229+
*outs[j]);
230+
}
231+
if (ins[j]->numel() != 0UL) {
232+
offset += ins[j]->dims()[axis];
233+
}
234+
}
235+
}
236+
189237
} // namespace custom_kernel
190238

191239
PD_REGISTER_PLUGIN_KERNEL(concat,
@@ -207,6 +255,5 @@ PD_REGISTER_PLUGIN_KERNEL(concat_grad,
207255
int,
208256
int64_t,
209257
float,
210-
double,
211258
phi::dtype::float16,
212259
phi::dtype::bfloat16) {}

0 commit comments

Comments
 (0)