@@ -550,10 +550,10 @@ void InterpolateKernel(
550
550
// Priority: SizeTensor > OutSize > Scale > scale > out_h & out_w
551
551
if (size_tensor && size_tensor->size () > 0 ) {
552
552
auto list_new_shape_tensor = size_tensor.get ();
553
- std::vector< int32_t > output_h ( 1 );
554
- std::vector< int32_t > output_w ( 1 );
555
- TensorToVector (dev_ctx, *(list_new_shape_tensor[ 0 ]), dev_ctx, &output_h);
556
- TensorToVector (dev_ctx, *( list_new_shape_tensor[1 ]), dev_ctx, &output_w );
553
+ auto output_h =
554
+ get_new_data_from_tensor< int >(dev_ctx, list_new_shape_tensor[ 0 ] );
555
+ auto output_w =
556
+ get_new_data_from_tensor< int > (dev_ctx, list_new_shape_tensor[1 ]);
557
557
out_h = output_h[0 ];
558
558
out_w = output_w[0 ];
559
559
} else if (out_size) {
@@ -933,25 +933,37 @@ PD_REGISTER_PLUGIN_KERNEL(nearest_interp,
933
933
ALL_LAYOUT,
934
934
custom_kernel::NearestInterpKernel,
935
935
float ,
936
- phi::dtype::float16) {}
936
+ phi::dtype::float16) {
937
+ kernel->InputAt (2 ).SetBackend (phi::Backend::ALL_BACKEND);
938
+ kernel->InputAt (3 ).SetBackend (phi::Backend::ALL_BACKEND);
939
+ }
937
940
938
941
PD_REGISTER_PLUGIN_KERNEL (nearest_interp_grad,
939
942
npu,
940
943
ALL_LAYOUT,
941
944
custom_kernel::NearestInterpGradKernel,
942
945
float ,
943
- phi::dtype::float16) {}
946
+ phi::dtype::float16) {
947
+ kernel->InputAt (2 ).SetBackend (phi::Backend::ALL_BACKEND);
948
+ kernel->InputAt (3 ).SetBackend (phi::Backend::ALL_BACKEND);
949
+ }
944
950
945
951
PD_REGISTER_PLUGIN_KERNEL (bilinear_interp,
946
952
npu,
947
953
ALL_LAYOUT,
948
954
custom_kernel::BilinearInterpKernel,
949
955
float ,
950
- phi::dtype::float16) {}
956
+ phi::dtype::float16) {
957
+ kernel->InputAt (2 ).SetBackend (phi::Backend::ALL_BACKEND);
958
+ kernel->InputAt (3 ).SetBackend (phi::Backend::ALL_BACKEND);
959
+ }
951
960
952
961
PD_REGISTER_PLUGIN_KERNEL (bilinear_interp_grad,
953
962
npu,
954
963
ALL_LAYOUT,
955
964
custom_kernel::BilinearInterpGradKernel,
956
965
float ,
957
- phi::dtype::float16) {}
966
+ phi::dtype::float16) {
967
+ kernel->InputAt (2 ).SetBackend (phi::Backend::ALL_BACKEND);
968
+ kernel->InputAt (3 ).SetBackend (phi::Backend::ALL_BACKEND);
969
+ }
0 commit comments