Skip to content

Commit 3075229

Browse files
committed
Fix the issue with determining whether operators in control flow are registered.
1 parent 040c051 commit 3075229

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

paddle2onnx/mapper/exporter.cc

+3-5
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,12 @@ bool ModelExporter::IsOpsRegistered(const PaddlePirParser& pir_parser,
4141
bool enable_experimental_op) {
4242
OnnxHelper temp_helper;
4343
std::set<std::string> unsupported_ops;
44-
for (auto op : pir_parser.global_blocks_ops) {
44+
for (auto op : pir_parser.total_blocks_ops) {
4545
if (op->name() == "pd_op.data" || op->name() == "pd_op.fetch") {
4646
continue;
4747
}
48-
if (op->name() == "pd_op.if") {
49-
continue;
50-
}
51-
if (op->name() == "pd_op.while") {
48+
if (op->name() == "pd_op.if" || op->name() == "pd_op.while" ||
49+
op->name() == "cf.yield") {
5250
continue;
5351
}
5452
std::string op_name = convert_pir_op_name(op->name());

paddle2onnx/parser/pir_parser.cc

+21
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "paddle/pir/include/core/builtin_op.h"
3333
#include "paddle/pir/include/core/ir_context.h"
3434
#include "paddle2onnx/proto/p2o_paddle.pb.h"
35+
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
3536

3637
phi::DataType TransToPhiDataType(pir::Type dtype) {
3738
if (dtype.isa<pir::BFloat16Type>()) {
@@ -211,6 +212,25 @@ void PaddlePirParser::GetOpArgNameMappings() {
211212
}
212213
}
213214

215+
void PaddlePirParser::GetAllBlocksOpsSet(pir::Block* block) {
216+
for(auto &op : block->ops()) {
217+
std::string op_name = op->name();
218+
if(op_name != "builtin.parameter") {
219+
total_blocks_ops.insert(op);
220+
if(op_name == "pd_op.if") {
221+
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
222+
pir::Block& true_block = if_op.true_block();
223+
GetAllBlocksOpsSet(&true_block);
224+
pir::Block& false_block = if_op.false_block();
225+
GetAllBlocksOpsSet(&false_block);
226+
} else if(op_name == "pd_op.while") {
227+
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
228+
GetAllBlocksOpsSet(&while_op.body());
229+
}
230+
}
231+
}
232+
}
233+
214234
std::string PaddlePirParser::GetOpArgName(int64_t op_id,
215235
std::string name,
216236
bool if_in_sub_block) const {
@@ -436,6 +456,7 @@ bool PaddlePirParser::Init(const std::string& _model,
436456
GetGlobalBlockInputOutputInfo();
437457
GetAllOpOutputName();
438458
GetOpArgNameMappings();
459+
GetAllBlocksOpsSet(pir_program_->block());
439460
return true;
440461
}
441462
int PaddlePirParser::NumOfBlocks() const {

paddle2onnx/parser/pir_parser.h

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class PaddlePirParser {
3737
// recoring set of operators for sub block
3838
mutable std::vector<pir::Operation*>
3939
sub_blocks_ops; // todo(wangmingkai02): delete sub_blocks_ops
40+
// recoring set of operators for all blocks
41+
std::set<pir::Operation*> total_blocks_ops;
4042
// recording args of while op body name info
4143
std::unordered_map<pir::detail::ValueImpl*, pir::detail::ValueImpl*>
4244
while_op_input_value_map;
@@ -270,6 +272,7 @@ class PaddlePirParser {
270272
bool LoadParams(const std::string& path);
271273
bool GetParamValueName(std::vector<std::string>* var_names);
272274
void GetGlobalBlocksOps();
275+
void GetAllBlocksOpsSet(pir::Block *block);
273276
void GetGlobalBlockInputOutputInfo();
274277
void GetGlobalBlockInputValueName();
275278
void GetGlobalBlockOutputValueName();

0 commit comments

Comments
 (0)