@@ -145,11 +145,11 @@ void ConcatKernel(const Context& dev_ctx,
145
145
}
146
146
147
147
template <typename T, typename Context>
148
- void ConcatGradKernel (const Context& dev_ctx,
149
- const std::vector<const phi::DenseTensor*>& ins,
150
- const phi::DenseTensor& dout,
151
- const phi::Scalar& axis_scalar,
152
- std::vector<phi::DenseTensor*> outs) {
148
+ void AclopConcatGradKernel (const Context& dev_ctx,
149
+ const std::vector<const phi::DenseTensor*>& ins,
150
+ const phi::DenseTensor& dout,
151
+ const phi::Scalar& axis_scalar,
152
+ std::vector<phi::DenseTensor*> outs) {
153
153
auto stream = dev_ctx.stream ();
154
154
155
155
int axis = axis_scalar.to <int >();
@@ -186,6 +186,54 @@ void ConcatGradKernel(const Context& dev_ctx,
186
186
}
187
187
}
188
188
189
+ template <typename T, typename Context>
190
+ void ConcatGradKernel (const Context& dev_ctx,
191
+ const std::vector<const phi::DenseTensor*>& ins,
192
+ const phi::DenseTensor& dout,
193
+ const phi::Scalar& axis_scalar,
194
+ std::vector<phi::DenseTensor*> outs) {
195
+ DO_COMPATIBILITY (aclnnSliceV2,
196
+ (custom_kernel::AclopConcatGradKernel<T, Context>(
197
+ dev_ctx, ins, dout, axis_scalar, outs)));
198
+ auto stream = dev_ctx.stream ();
199
+
200
+ int axis = axis_scalar.to <int >();
201
+ axis = ComputeAxis (static_cast <int64_t >(axis),
202
+ static_cast <int64_t >(ins[0 ]->dims ().size ()));
203
+
204
+ std::vector<int64_t > axes_t ;
205
+ axes_t .push_back (axis);
206
+
207
+ int offset = 0 ;
208
+ for (size_t j = 0 ; j < outs.size (); ++j) {
209
+ if (outs[j] && outs[j]->numel () != 0UL ) {
210
+ dev_ctx.template Alloc <T>(outs[j]);
211
+
212
+ std::vector<int64_t > starts_array;
213
+ starts_array.push_back (offset);
214
+ std::vector<int64_t > ends_array;
215
+ ends_array.push_back (ins[j]->dims ()[axis] + offset);
216
+
217
+ std::vector<int64_t > steps;
218
+ for (int i = 0 ; i < outs[j]->dims ().size (); i++) {
219
+ steps.push_back (1.0 );
220
+ }
221
+
222
+ EXEC_NPU_CMD (aclnnSliceV2,
223
+ dev_ctx,
224
+ dout,
225
+ starts_array,
226
+ ends_array,
227
+ axes_t ,
228
+ steps,
229
+ *outs[j]);
230
+ }
231
+ if (ins[j]->numel () != 0UL ) {
232
+ offset += ins[j]->dims ()[axis];
233
+ }
234
+ }
235
+ }
236
+
189
237
} // namespace custom_kernel
190
238
191
239
PD_REGISTER_PLUGIN_KERNEL (concat,
@@ -207,6 +255,5 @@ PD_REGISTER_PLUGIN_KERNEL(concat_grad,
207
255
int ,
208
256
int64_t ,
209
257
float ,
210
- double ,
211
258
phi::dtype::float16,
212
259
phi::dtype::bfloat16) {}
0 commit comments