File tree: 2 files changed, +16 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -99,8 +99,15 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
99
99
(*kernel_fn)(*dev_ctx, input_x, kernel_out);
100
100
} else {
101
101
std::vector<const phi::TensorBase*> input_x (x.size ());
102
+ std::vector<std::shared_ptr<phi::DenseTensor>> temp_dense_tensots;
103
+ temp_dense_tensots.reserve (x.size ());
102
104
for (size_t i = 0 ; i < input_x.size (); ++i) {
103
- input_x[i] = x[i].impl ().get ();
105
+ if (phi::DenseTensor::classof (x[i].impl ().get ())) {
106
+ temp_dense_tensots.push_back (PrepareData (x[i], kernel.InputAt (0 ), {}));
107
+ input_x[i] = temp_dense_tensots.back ().get ();
108
+ } else {
109
+ input_x[i] = x[i].impl ().get ();
110
+ }
104
111
}
105
112
auto x_meta_vec = MakeMetaTensor (input_x);
106
113
std::vector<const phi::MetaTensor*> x_metas (x_meta_vec.size ());
@@ -118,6 +125,9 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
118
125
auto * kernel_fn = kernel.GetVariadicKernelFn <kernel_signature>();
119
126
120
127
(*kernel_fn)(*dev_ctx, input_x, kernel_out);
128
+ if (kernel_result.has_fallback_cpu ) {
129
+ TransDataBackend (kernel_out, kernel_backend, kernel_out);
130
+ }
121
131
}
122
132
123
133
return api_output;
Original file line number | Diff line number | Diff line change
@@ -145,10 +145,12 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
145
145
kernel_key,
146
146
kernel_name));
147
147
148
- if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second .end ())
149
148
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
150
- || paddle::platform::is_in_xpu_black_list (TransToFluidOpName (kernel_name))
151
-
149
+ VLOG (6 ) << " fluid_op_name: " << TransToFluidOpName (kernel_name);
150
+ if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second .end ()) ||
151
+ paddle::platform::is_in_xpu_black_list (TransToFluidOpName (kernel_name))
152
+ #else
153
+ if ((FLAGS_enable_api_kernel_fallback && kernel_iter == iter->second .end ())
152
154
#endif
153
155
) {
154
156
// Fallback CPU backend
You can’t perform that action at this time.
0 commit comments