@@ -40,7 +40,12 @@ phi::KernelKey GetKernelKey(
40
40
const phi::Place& place,
41
41
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair) {
42
42
if (op->name () == " pd.feed" ) {
43
- return {phi::Backend::CPU, phi::DataLayout::ANY, phi::DataType::FLOAT32};
43
+ // NOTE, for now feed op don't need a kernel, so the data type from Op
44
+ // Result the next op use base program datatype
45
+ return {phi::Backend::CPU,
46
+ phi::DataLayout::ANY,
47
+ TransToPhiDataType (
48
+ op->result (0 ).type ().dyn_cast <DenseTensorType>().dtype ())};
44
49
}
45
50
phi::Backend kernel_backend = phi::Backend::UNDEFINED;
46
51
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
@@ -223,23 +228,27 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
223
228
result_type.dyn_cast <dialect::DenseTensorType>());
224
229
op_output_types.push_back (allocated_dense_tensor_dtype);
225
230
} else if (result_type.isa <ir::VectorType>()) {
226
- auto pos1 = result_type.dyn_cast <ir::VectorType>().data ()[0 ];
227
-
228
- if (pos1.isa <dialect::DenseTensorType>()) {
229
- auto allocated_dense_tensor_dtype =
230
- paddle::dialect::AllocatedDenseTensorType::get (
231
- ctx,
232
- phi::TransToPhiPlace (kernel_key.backend ()),
233
- pos1.dyn_cast <dialect::DenseTensorType>());
234
- op_output_types.push_back (allocated_dense_tensor_dtype);
235
- } else {
236
- PADDLE_THROW (phi::errors::Unimplemented (
237
- " only support dense tensor in vector type for now" ));
231
+ std::vector<ir::Type> vec_inner_types;
232
+ auto base_types = result_type.dyn_cast <ir::VectorType>().data ();
233
+ for (size_t j = 0 ; j < base_types.size (); ++j) {
234
+ if (base_types[j].isa <dialect::DenseTensorType>()) {
235
+ auto allocated_dense_tensor_dtype =
236
+ paddle::dialect::AllocatedDenseTensorType::get (
237
+ ctx,
238
+ phi::TransToPhiPlace (kernel_key.backend ()),
239
+ base_types[j].dyn_cast <dialect::DenseTensorType>());
240
+ vec_inner_types.push_back (allocated_dense_tensor_dtype);
241
+ } else {
242
+ PADDLE_THROW (phi::errors::Unimplemented (
243
+ " only support dense tensor in vector type for now" ));
244
+ }
238
245
}
239
246
240
- ir::Type t1 = ir::VectorType::get (ctx, op_output_types);
241
- op_output_types.clear ();
247
+ ir::Type t1 = ir::VectorType::get (ctx, vec_inner_types);
242
248
op_output_types.push_back (t1);
249
+ } else {
250
+ PADDLE_THROW (phi::errors::Unimplemented (
251
+ " Result type only support DenseTensorType and VectorType" ));
243
252
}
244
253
}
245
254
}
0 commit comments