Skip to content

Commit 43ad0b1

Browse files
authored
Fix the bug where a device memory address appears when the abs_grad kernel falls back to CPU. test=kunlun (PaddlePaddle#47186)
1 parent 340009d commit 43ad0b1

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

paddle/phi/api/lib/api_custom_impl.cc

+11-1
Original file line number | Diff line number | Diff line change
@@ -99,8 +99,15 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
9999
(*kernel_fn)(*dev_ctx, input_x, kernel_out);
100100
} else {
101101
std::vector<const phi::TensorBase*> input_x(x.size());
102+
std::vector<std::shared_ptr<phi::DenseTensor>> temp_dense_tensots;
103+
temp_dense_tensots.reserve(x.size());
102104
for (size_t i = 0; i < input_x.size(); ++i) {
103-
input_x[i] = x[i].impl().get();
105+
if (phi::DenseTensor::classof(x[i].impl().get())) {
106+
temp_dense_tensots.push_back(PrepareData(x[i], kernel.InputAt(0), {}));
107+
input_x[i] = temp_dense_tensots.back().get();
108+
} else {
109+
input_x[i] = x[i].impl().get();
110+
}
104111
}
105112
auto x_meta_vec = MakeMetaTensor(input_x);
106113
std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
@@ -118,6 +125,9 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
118125
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
119126

120127
(*kernel_fn)(*dev_ctx, input_x, kernel_out);
128+
if (kernel_result.has_fallback_cpu) {
129+
TransDataBackend(kernel_out, kernel_backend, kernel_out);
130+
}
121131
}
122132

123133
return api_output;

paddle/phi/core/kernel_factory.cc

+5-3
Original file line number | Diff line number | Diff line change
@@ -145,10 +145,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
145145
kernel_key,
146146
kernel_name));
147147

148-
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end())
149148
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
150-
|| paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name))
151-
149+
VLOG(6) << "fluid_op_name: " << TransToFluidOpName(kernel_name);
150+
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end()) ||
151+
paddle::platform::is_in_xpu_black_list(TransToFluidOpName(kernel_name))
152+
#else
153+
if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second.end())
152154
#endif
153155
) {
154156
// Fallback CPU backend

0 commit comments

Comments
 (0)