Commit b72edb8

[NPU] improve bilinear_interp speed. (PaddlePaddle#1287)
1 parent def8dc1 commit b72edb8
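
This commit adds an aclnn fast path for bilinear interpolation on NPU: when the tensors are NCHW and the parameters can be expressed by the single-operator aclnn API, the forward and backward kernels dispatch to aclnnUpsampleBilinear2d / aclnnUpsampleBilinear2dBackward through EXEC_NPU_CMD, keeping the original composed-op implementation as a fallback. It also replaces ReduceSumD, TransposeD, and TensorScatterAdd with ReduceSum, Transpose, and ScatterNdAdd (axes and perm passed as inputs rather than attributes) and reuses the gx_zero buffer instead of allocating a separate gx_t. The sketch below condenses the dispatch rule the patch adds for the NCHW case; it is a restatement for readability, not part of the patch, and the helper name use_aclnn_path is illustrative only (the patch inlines this logic in BilinearFwdNpu and BilinearBwdNpu).

// Illustrative restatement of the dispatch rule added in BilinearFwdNpu /
// BilinearBwdNpu (NCHW only); the helper name is not part of the patch.
bool use_aclnn_path(int in_h, int in_w, int out_h, int out_w,
                    float scale_h, float scale_w,
                    bool align_corners, int align_mode) {
  // Scales are either unset (-1) or must exactly match the out/in ratios.
  const bool scale_matches =
      (scale_h == -1 && scale_w == -1) ||
      (static_cast<float>(out_h) / in_h == scale_h &&
       static_cast<float>(out_w) / in_w == scale_w);

  if (align_corners) return true;  // always routed to aclnnUpsampleBilinear2d
  if (scale_matches && align_mode == 0) {
    // A 1x1 output also takes the aclnn path, with align_corners forced true;
    // an output where only one of H/W is 1 falls through to the old path.
    if ((out_h != 1 && out_w != 1) || (out_h == 1 && out_w == 1)) return true;
  }
  return false;  // fall back to the composed NpuOpRunner implementation
}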

File tree

1 file changed (+186, -19 lines)

backends/npu/kernels/interpolate_kernel.cc

Lines changed: 186 additions & 19 deletions
@@ -75,9 +75,13 @@ struct InterpolateFunction {
                  phi::DenseTensor* y,
                  const std::vector<int>& dim,
                  bool keep_dims = true) {
-    const auto& runner = NpuOpRunner(
-        "ReduceSumD", {*x}, {*y}, {{"axes", dim}, {"keep_dims", keep_dims}});
-    runner.Run(stream);
+    NpuOpRunner runner;
+    runner.SetType("ReduceSum")
+        .AddInput(*x)
+        .AddInput(dev_ctx, std::move(dim))
+        .AddOutput(*y)
+        .AddAttrs({{"keep_dims", keep_dims}})
+        .Run(dev_ctx.stream());
   }
   void Add(const phi::DenseTensor* x,
            const phi::DenseTensor* y,
@@ -157,37 +161,42 @@ struct InterpolateFunction {
     phi::DenseTensor gy_t;
     gy_t.Resize(y_new_shape);
     dev_ctx.template Alloc<T>(&gy_t);
-    Transpose(gy, &gy_t, axis_swap);
+    Transpose(dev_ctx, gy, &gy_t, axis_swap);
+
     // 2 scatter
     auto x_new_shape = gx->dims();
     auto xt = x_new_shape[axis];
     x_new_shape[axis] = x_new_shape[0];
     x_new_shape[0] = xt;
-    phi::DenseTensor gx_zero, gx_t;
+    phi::DenseTensor gx_zero;
     gx_zero.Resize(x_new_shape);
-    gx_t.Resize(x_new_shape);
     dev_ctx.template Alloc<T>(&gx_zero);
-    dev_ctx.template Alloc<T>(&gx_t);
     FillNpuTensorWithConstant<T>(&gx_zero, dev_ctx, static_cast<T>(0));
     gx_zero.Resize(x_new_shape);
-    Scatter(&gx_zero, indices, &gy_t, &gx_t);
+    Scatter(dev_ctx, &gx_zero, indices, &gy_t, &gx_zero);
+
     // 3 gx swapaxis: axis, 0
-    Transpose(&gx_t, gx, axis_swap);
+    Transpose(dev_ctx, &gx_zero, gx, axis_swap);
   }
-  void Scatter(const phi::DenseTensor* x,
+  void Scatter(const Context& dev_ctx,
+               const phi::DenseTensor* x,
               const phi::DenseTensor* index,
               const phi::DenseTensor* updates,
               phi::DenseTensor* y) {
     const auto& runner =
-        NpuOpRunner("TensorScatterAdd", {*x, *index, *updates}, {*y}, {});
+        NpuOpRunner("ScatterNdAdd", {*x, *index, *updates}, {*y}, {});
     runner.Run(stream);
   }
-  void Transpose(const phi::DenseTensor* x,
+  void Transpose(const Context& dev_ctx,
+                 const phi::DenseTensor* x,
                 phi::DenseTensor* y,
                 const std::vector<int>& axis) {
-    const auto& runner =
-        NpuOpRunner("TransposeD", {*x}, {*y}, {{"perm", axis}});
-    runner.Run(stream);
+    NpuOpRunner runner;
+    runner.SetType("Transpose")
+        .AddInput(*x)
+        .AddInput(dev_ctx, std::move(axis))
+        .AddOutput(*y)
+        .Run(dev_ctx.stream());
   }
   void Muls(const phi::DenseTensor* x, float scalar, phi::DenseTensor* y) {
     const auto& runner = NpuOpRunner("Muls", {*x}, {*y}, {{"value", scalar}});
@@ -350,6 +359,32 @@ void BilinearParamTensorCompute(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context>
+void BilinearFwdAclnn(const Context& dev_ctx,
+                      const phi::DenseTensor* input,
+                      phi::DenseTensor* output,
+                      const float scale_h,
+                      const float scale_w,
+                      const bool align_corners,
+                      const int align_mode,
+                      const phi::DataLayout& data_layout) {
+  auto outdim = output->dims();
+  double h = 1;
+  double w = 1;
+  std::vector<int64_t> outsize;
+  for (int i = 2; i < outdim.size(); i++) {
+    outsize.push_back(outdim[i]);
+  }
+  EXEC_NPU_CMD(aclnnUpsampleBilinear2d,
+               dev_ctx,
+               *input,
+               outsize,
+               align_corners,
+               h,
+               w,
+               *output);
+}
+
 template <typename T, typename Context>
 void BilinearFwdNpu(const Context& dev_ctx,
                     const phi::DenseTensor* input,
@@ -359,11 +394,60 @@ void BilinearFwdNpu(const Context& dev_ctx,
                     const bool align_corners,
                     const int align_mode,
                     const phi::DataLayout& data_layout) {
-  InterpolateFunction<T, Context> F(dev_ctx);
-  auto place = dev_ctx.GetPlace();
   auto outdim = output->dims();
   auto indim = input->dims();
 
+  if (data_layout == phi::DataLayout::NCHW) {
+    int indim_h = indim[2];
+    int indim_w = indim[3];
+    int outdim_h = outdim[2];
+    int outdim_w = outdim[3];
+
+    bool aclnn_flag = false;
+    if (scale_h == -1 && scale_w == -1) {
+      aclnn_flag = true;
+    } else if (static_cast<float>(outdim_h) / static_cast<float>(indim_h) ==
+                   scale_h &&
+               static_cast<float>(outdim_w) / static_cast<float>(indim_w) ==
+                   scale_w) {
+      aclnn_flag = true;
+    }
+
+    if (align_corners == true) {
+      return custom_kernel::BilinearFwdAclnn<T, Context>(dev_ctx,
+                                                         input,
+                                                         output,
+                                                         scale_h,
+                                                         scale_w,
+                                                         align_corners,
+                                                         align_mode,
+                                                         data_layout);
+    } else if (aclnn_flag && align_mode == 0) {
+      if (outdim_h != 1 && outdim_w != 1) {
+        return custom_kernel::BilinearFwdAclnn<T, Context>(dev_ctx,
+                                                           input,
+                                                           output,
+                                                           scale_h,
+                                                           scale_w,
+                                                           align_corners,
+                                                           align_mode,
+                                                           data_layout);
+      } else if (outdim_h == 1 && outdim_w == 1) {
+        return custom_kernel::BilinearFwdAclnn<T, Context>(dev_ctx,
+                                                           input,
+                                                           output,
+                                                           scale_h,
+                                                           scale_w,
+                                                           true,
+                                                           align_mode,
+                                                           data_layout);
+      }
+    }
+  }
+
+  InterpolateFunction<T, Context> F(dev_ctx);
+  auto place = dev_ctx.GetPlace();
+
   int axis_h, axis_w;
   int out_h, out_w, in_h, in_w;
   float ratio_h, ratio_w;
@@ -461,6 +545,40 @@ void BilinearFwdNpu(const Context& dev_ctx,
   F.ReduceSum(&out_x4, output, std::vector<int>{0}, false);
 }
 
+template <typename T, typename Context>
+void BilinearBwdAclnn(const Context& dev_ctx,
+                      const phi::DenseTensor* gout,
+                      phi::DenseTensor* gin,
+                      const float scale_h,
+                      const float scale_w,
+                      const bool align_corners,
+                      const int align_mode,
+                      const phi::DataLayout& data_layout) {
+  auto indim = gin->dims();
+  auto outdim = gout->dims();
+  double h = 1;
+  double w = 1;
+
+  std::vector<int64_t> outputsize;
+  for (int i = 2; i < outdim.size(); i++) {
+    outputsize.push_back(outdim[i]);
+  }
+  std::vector<int64_t> inputsize;
+  for (int i = 0; i < indim.size(); i++) {
+    inputsize.push_back(indim[i]);
+  }
+  // dev_ctx.template Alloc<T>(gin);
+  EXEC_NPU_CMD(aclnnUpsampleBilinear2dBackward,
+               dev_ctx,
+               *gout,
+               outputsize,
+               inputsize,
+               align_corners,
+               h,
+               w,
+               *gin);
+}
+
 template <typename T, typename Context>
 void BilinearBwdNpu(const Context& dev_ctx,
                     const phi::DenseTensor* gout,
@@ -470,11 +588,60 @@ void BilinearBwdNpu(const Context& dev_ctx,
                     const bool align_corners,
                     const int align_mode,
                     const phi::DataLayout& data_layout) {
-  InterpolateFunction<T, Context> F(dev_ctx);
-  auto place = dev_ctx.GetPlace();
   auto outdim = gout->dims();
   auto indim = gin->dims();
 
+  if (data_layout == phi::DataLayout::NCHW) {
+    int indim_h = indim[2];
+    int indim_w = indim[3];
+    int outdim_h = outdim[2];
+    int outdim_w = outdim[3];
+
+    bool aclnn_flag = false;
+    if (scale_h == -1 && scale_w == -1) {
+      aclnn_flag = true;
+    } else if (static_cast<float>(outdim_h) / static_cast<float>(indim_h) ==
+                   scale_h &&
+               static_cast<float>(outdim_w) / static_cast<float>(indim_w) ==
+                   scale_w) {
+      aclnn_flag = true;
+    }
+
+    if (align_corners == true) {
+      return custom_kernel::BilinearBwdAclnn<T, Context>(dev_ctx,
+                                                         gout,
+                                                         gin,
+                                                         scale_h,
+                                                         scale_w,
+                                                         align_corners,
+                                                         align_mode,
+                                                         data_layout);
+    } else if (aclnn_flag && align_mode == 0) {
+      if (outdim_h != 1 && outdim_w != 1) {
+        return custom_kernel::BilinearBwdAclnn<T, Context>(dev_ctx,
+                                                           gout,
+                                                           gin,
+                                                           scale_h,
+                                                           scale_w,
+                                                           align_corners,
+                                                           align_mode,
+                                                           data_layout);
+      } else if (outdim_h == 1 && outdim_w == 1) {
+        return custom_kernel::BilinearBwdAclnn<T, Context>(dev_ctx,
+                                                           gout,
+                                                           gin,
+                                                           scale_h,
+                                                           scale_w,
+                                                           true,
+                                                           align_mode,
+                                                           data_layout);
+      }
+    }
+  }
+
+  InterpolateFunction<T, Context> F(dev_ctx);
+  auto place = dev_ctx.GetPlace();
+
   int axis_h, axis_w;
   int out_h, out_w, in_h, in_w;
   float ratio_h, ratio_w;
