Skip to content

Commit 9c2dae1

Browse files
authored
Fix output vector type bug (#54865)
* Add a fetch kernel
* Support fetch var in the new IR
* Fix bug
* Polish code
* Change `array_equal` to `np.testing`
* Support feed in the new IR
* Fix bug
* Try to hack the combine op
* Add a scope guard
* Revert the atan2 op
* Polish code
* Fix vector type bug
* Modify feed data type
1 parent 99c593b commit 9c2dae1

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ phi::KernelKey GetKernelKey(
4040
const phi::Place& place,
4141
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair) {
4242
if (op->name() == "pd.feed") {
43-
return {phi::Backend::CPU, phi::DataLayout::ANY, phi::DataType::FLOAT32};
43+
// NOTE, for now feed op don't need a kernel, so the data type from Op
44+
// Result the next op use base program datatype
45+
return {phi::Backend::CPU,
46+
phi::DataLayout::ANY,
47+
TransToPhiDataType(
48+
op->result(0).type().dyn_cast<DenseTensorType>().dtype())};
4449
}
4550
phi::Backend kernel_backend = phi::Backend::UNDEFINED;
4651
phi::DataLayout kernel_layout = phi::DataLayout::UNDEFINED;
@@ -223,23 +228,27 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
223228
result_type.dyn_cast<dialect::DenseTensorType>());
224229
op_output_types.push_back(allocated_dense_tensor_dtype);
225230
} else if (result_type.isa<ir::VectorType>()) {
226-
auto pos1 = result_type.dyn_cast<ir::VectorType>().data()[0];
227-
228-
if (pos1.isa<dialect::DenseTensorType>()) {
229-
auto allocated_dense_tensor_dtype =
230-
paddle::dialect::AllocatedDenseTensorType::get(
231-
ctx,
232-
phi::TransToPhiPlace(kernel_key.backend()),
233-
pos1.dyn_cast<dialect::DenseTensorType>());
234-
op_output_types.push_back(allocated_dense_tensor_dtype);
235-
} else {
236-
PADDLE_THROW(phi::errors::Unimplemented(
237-
"only support dense tensor in vector type for now"));
231+
std::vector<ir::Type> vec_inner_types;
232+
auto base_types = result_type.dyn_cast<ir::VectorType>().data();
233+
for (size_t j = 0; j < base_types.size(); ++j) {
234+
if (base_types[j].isa<dialect::DenseTensorType>()) {
235+
auto allocated_dense_tensor_dtype =
236+
paddle::dialect::AllocatedDenseTensorType::get(
237+
ctx,
238+
phi::TransToPhiPlace(kernel_key.backend()),
239+
base_types[j].dyn_cast<dialect::DenseTensorType>());
240+
vec_inner_types.push_back(allocated_dense_tensor_dtype);
241+
} else {
242+
PADDLE_THROW(phi::errors::Unimplemented(
243+
"only support dense tensor in vector type for now"));
244+
}
238245
}
239246

240-
ir::Type t1 = ir::VectorType::get(ctx, op_output_types);
241-
op_output_types.clear();
247+
ir::Type t1 = ir::VectorType::get(ctx, vec_inner_types);
242248
op_output_types.push_back(t1);
249+
} else {
250+
PADDLE_THROW(phi::errors::Unimplemented(
251+
"Result type only support DenseTensorType and VectorType"));
243252
}
244253
}
245254
}

0 commit comments

Comments (0)