From 0d6c82e5a678d8f59a492356b1e3342e3afc04e5 Mon Sep 17 00:00:00 2001 From: root <1009134431@qq.com> Date: Tue, 18 Oct 2022 08:47:31 +0000 Subject: [PATCH] Fix the bug where the device memory address appears in abs_grad kernel fallback to CPU. test=kunlun --- paddle/phi/api/lib/api_custom_impl.cc | 12 +++++++++++- paddle/phi/core/kernel_factory.cc | 8 +++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 2fa40320e55830..19a9b808dd6f6d 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -99,8 +99,15 @@ Tensor add_n_impl(const std::vector& x) { (*kernel_fn)(*dev_ctx, input_x, kernel_out); } else { std::vector input_x(x.size()); + std::vector> temp_dense_tensots; + temp_dense_tensots.reserve(x.size()); for (size_t i = 0; i < input_x.size(); ++i) { - input_x[i] = x[i].impl().get(); + if (phi::DenseTensor::classof(x[i].impl().get())) { + temp_dense_tensots.push_back(PrepareData(x[i], kernel.InputAt(0), {})); + input_x[i] = temp_dense_tensots.back().get(); + } else { + input_x[i] = x[i].impl().get(); + } } auto x_meta_vec = MakeMetaTensor(input_x); std::vector x_metas(x_meta_vec.size()); @@ -118,6 +125,9 @@ Tensor add_n_impl(const std::vector& x) { auto* kernel_fn = kernel.GetVariadicKernelFn(); (*kernel_fn)(*dev_ctx, input_x, kernel_out); + if (kernel_result.has_fallback_cpu) { + TransDataBackend(kernel_out, kernel_backend, kernel_out); + } } return api_output; diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 480882550dbcad..bbfe10591f0f92 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -145,10 +145,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError( kernel_key, kernel_name)); - if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) #if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP) - || paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) - + VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name); + if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) || + paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name)) +#else + if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) #endif ) { // Fallback CPU backend