Skip to content

Commit d438673

Browse files
authored
fix interpolate kernel to register to all backend (PaddlePaddle#232)
1 parent 42e50dc commit d438673

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

backends/npu/kernels/interpolate_kernel.cc

100755100644
Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -550,10 +550,10 @@ void InterpolateKernel(
550550
// Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
551551
if (size_tensor && size_tensor->size() > 0) {
552552
auto list_new_shape_tensor = size_tensor.get();
553-
std::vector<int32_t> output_h(1);
554-
std::vector<int32_t> output_w(1);
555-
TensorToVector(dev_ctx, *(list_new_shape_tensor[0]), dev_ctx, &output_h);
556-
TensorToVector(dev_ctx, *(list_new_shape_tensor[1]), dev_ctx, &output_w);
553+
auto output_h =
554+
get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[0]);
555+
auto output_w =
556+
get_new_data_from_tensor<int>(dev_ctx, list_new_shape_tensor[1]);
557557
out_h = output_h[0];
558558
out_w = output_w[0];
559559
} else if (out_size) {
@@ -933,25 +933,37 @@ PD_REGISTER_PLUGIN_KERNEL(nearest_interp,
933933
ALL_LAYOUT,
934934
custom_kernel::NearestInterpKernel,
935935
float,
936-
phi::dtype::float16) {}
936+
phi::dtype::float16) {
937+
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
938+
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
939+
}
937940

938941
PD_REGISTER_PLUGIN_KERNEL(nearest_interp_grad,
939942
npu,
940943
ALL_LAYOUT,
941944
custom_kernel::NearestInterpGradKernel,
942945
float,
943-
phi::dtype::float16) {}
946+
phi::dtype::float16) {
947+
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
948+
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
949+
}
944950

945951
PD_REGISTER_PLUGIN_KERNEL(bilinear_interp,
946952
npu,
947953
ALL_LAYOUT,
948954
custom_kernel::BilinearInterpKernel,
949955
float,
950-
phi::dtype::float16) {}
956+
phi::dtype::float16) {
957+
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
958+
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
959+
}
951960

952961
PD_REGISTER_PLUGIN_KERNEL(bilinear_interp_grad,
953962
npu,
954963
ALL_LAYOUT,
955964
custom_kernel::BilinearInterpGradKernel,
956965
float,
957-
phi::dtype::float16) {}
966+
phi::dtype::float16) {
967+
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
968+
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
969+
}

0 commit comments

Comments
 (0)