Skip to content

Commit d4041a5

Browse files
committed
fix the issue with while input values and args
1 parent 3c7fd18 commit d4041a5

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

paddle2onnx/mapper/exporter.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ int32_t ModelExporter::GetMinOpsetVersion(const PaddlePirParser& pir_parser,
252252
current_opset = current_opset > 11 ? current_opset : 11;
253253
} else if (op_name == "pd_op.while") {
254254
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
255+
pir_parser.GetWhileInputValuesAndArgsMappings(while_op);
255256
current_opset = GetCfBlockMinOpsetVersion(pir_parser, while_op.body());
256257
current_opset = current_opset > 11 ? current_opset : 11;
257258

@@ -483,7 +484,7 @@ ONNX_NAMESPACE::GraphProto ModelExporter::ExportIfBlock(
483484
temp_outputs.push_back(std::move(MakeValueInfo(cond_info[0])));
484485
if (value.defining_op() == nullptr) {
485486
value =
486-
pir::Value(pir_parser.while_op_input_value_map[&(*(value.impl()))]);
487+
pir::Value(pir_parser.while_op_values_args_map[&(*(value.impl()))]);
487488
}
488489
if (value.defining_op()->GetParent() != &block) {
489490
temp_inputs.push_back(std::move(MakeValueInfo(cond_info[0])));

paddle2onnx/mapper/while.cc

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,12 @@ void ModelExporter::ExportWhile(PaddlePirParser& pir_parser,
2525
std::vector<TensorInfo> outputs_info;
2626
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
2727
auto cond_info = pir_parser.GetTensorInfo(while_op.cond());
28-
// mapping args and inputs in while op using while_op_input_value_map
29-
std::vector<pir::detail::ValueImpl*> while_op_input_value_address;
30-
std::vector<pir::detail::ValueImpl*> while_op_input_arg_address;
31-
pir_parser.while_op_input_value_map
32-
.clear(); // wangmingkai02: handle nested loop situations in future.
33-
34-
// record input value address
3528
for (int index = 1; index < while_op.num_operands(); index++) {
3629
const pir::Value& value = while_op.operand_source(index);
3730
inputs_info.push_back(pir_parser.GetTensorInfo(
38-
pir_parser.GetOpOutputName(value), value.type()));
39-
while_op_input_value_address.push_back(
40-
&(*(value).impl())); // get value address
41-
}
42-
// record args value address
43-
std::vector<pir::Value> args = while_op.block_args();
44-
for (int i = 0; i < args.size(); i++) {
45-
const pir::Value& value = args[i];
46-
while_op_input_arg_address.push_back(&(*(value.impl())));
47-
}
48-
49-
// mapping
50-
for (int index = 0; index < while_op_input_value_address.size(); index++) {
51-
pir_parser.while_op_input_value_map[while_op_input_arg_address[index]] =
52-
while_op_input_value_address[index];
31+
pir_parser.GetSubBlockOpOutputName(value), value.type()));
5332
}
33+
pir_parser.GetWhileInputValuesAndArgsMappings(while_op);
5434

5535
std::vector<pir::Operation*> sub_blocks_ops_copy(pir_parser.sub_blocks_ops);
5636
pir_parser.sub_blocks_ops.clear();
@@ -120,7 +100,7 @@ void ModelExporter::ExportWhile(PaddlePirParser& pir_parser,
120100
input_names.push_back(inputs_info[i].name);
121101
}
122102
for (size_t i = 0; i < op->num_results(); i++) {
123-
output_names.push_back(pir_parser.GetOpOutputName(op->result(i)));
103+
output_names.push_back(pir_parser.GetSubBlockOpOutputName(op->result(i)));
124104
}
125105
auto loop_node = temp_helper->MakeNode("Loop", input_names, output_names);
126106
AddAttribute(loop_node, "body", graph);

paddle2onnx/parser/pir_parser.cc

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ std::string PaddlePirParser::GetOpOutputName(const pir::Value& source) const {
106106

107107
std::string PaddlePirParser::GetSubBlockOpOutputName(
108108
const pir::Value& source) const {
109-
auto it = while_op_input_value_map.find(&(*(source.impl())));
109+
auto it = while_op_values_args_map.find(&(*(source.impl())));
110110
pir::Operation* op;
111111
uint32_t output_idx;
112-
if (it != while_op_input_value_map.end()) {
112+
if (it != while_op_values_args_map.end()) {
113113
pir::Value value(it->second);
114114
op = value.defining_op();
115115
output_idx = value.dyn_cast<pir::OpResult>().index();
@@ -1029,4 +1029,34 @@ P2ODataType PaddlePirParser::TransPirDataType2OldIrDataType(
10291029
"PaddlePirParser::TransPirDataType2OnnxDataType.");
10301030
}
10311031
}
1032+
void PaddlePirParser::GetWhileInputValuesAndArgsMappings(
1033+
const paddle::dialect::WhileOp& while_op) const {
1034+
// mapping args and inputs in while op using while_op_values_args_map
1035+
std::vector<pir::detail::ValueImpl*> while_op_input_value_address;
1036+
std::vector<pir::detail::ValueImpl*> while_op_input_arg_address;
1037+
// record input value address
1038+
for (int index = 1; index < while_op.num_operands(); index++) {
1039+
const pir::Value& value = while_op.operand_source(index);
1040+
while_op_input_value_address.push_back(
1041+
&(*(value).impl())); // get value address
1042+
}
1043+
// record args value address
1044+
std::vector<pir::Value> args = while_op.block_args();
1045+
for (int i = 0; i < args.size(); i++) {
1046+
const pir::Value& value = args[i];
1047+
while_op_input_arg_address.push_back(&(*(value.impl())));
1048+
}
1049+
1050+
// mapping
1051+
for (int index = 0; index < while_op_input_value_address.size(); index++) {
1052+
auto arg_addr = while_op_input_arg_address[index];
1053+
if (while_op_values_args_map.count(arg_addr)) continue;
1054+
auto value_addr = while_op_input_value_address[index];
1055+
while (while_op_values_args_map.count(value_addr)) {
1056+
value_addr = while_op_values_args_map[value_addr];
1057+
}
1058+
while_op_values_args_map[arg_addr] = value_addr;
1059+
}
1060+
}
1061+
10321062
} // namespace paddle2onnx

paddle2onnx/parser/pir_parser.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "paddle/pir/include/core/value.h"
2323
#include "paddle2onnx/parser/tensor_utils.h"
2424
#include "paddle2onnx/proto/p2o_paddle.pb.h"
25+
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
2526
namespace paddle2onnx {
2627
class PaddlePirParser {
2728
public:
@@ -40,8 +41,8 @@ class PaddlePirParser {
4041
// recoring set of operators for all blocks
4142
std::set<pir::Operation*> total_blocks_ops;
4243
// recording args of while op body name info
43-
std::unordered_map<pir::detail::ValueImpl*, pir::detail::ValueImpl*>
44-
while_op_input_value_map;
44+
mutable std::unordered_map<pir::detail::ValueImpl*, pir::detail::ValueImpl*>
45+
while_op_values_args_map;
4546
int NumOfBlocks() const;
4647
// int NumOfOps(int block_idx) const;
4748
int NumOfProgramOps() const;
@@ -265,6 +266,8 @@ class PaddlePirParser {
265266
std::string tensor_arr_name) const;
266267
std::string GetTensorArrayName(int64_t op_id, bool if_in_sub_block) const;
267268
std::string GenOpInputOutputName(const std::string& name) const;
269+
void GetWhileInputValuesAndArgsMappings(
270+
const paddle::dialect::WhileOp& while_op) const;
268271

269272
private:
270273
bool IsAttrVar(const pir::Operation* op, const int64_t& attr_id) const;

0 commit comments

Comments
 (0)