Skip to content

Commit 2153f91

Browse files
Support dynamic shape in repeat_interleave op (#1589)
1 parent 60e89ce commit 2153f91

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

paddle2onnx/mapper/tensor/repeat_interleave.cc

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,68 @@ int32_t RepeatInterleaveMapper::GetMinOpsetVersion(bool verbose) {
2323
return op_version;
2424
}
2525

26-
void RepeatInterleaveMapper::Opset9() {
27-
auto x_info = GetInput("X"); // shape = [1, 2, 3]
28-
auto out_info = GetOutput("Out");
26+
void RepeatInterleaveMapper::DynamicRepeatInterleave(
27+
const std::vector<TensorInfo>& x_info,
28+
const std::vector<TensorInfo>& out_info) {
29+
std::string dim_size_name = "";
30+
31+
auto shape_node = helper_->MakeNode("Shape", {x_info[0].name}, 1);
32+
auto dim_node =
33+
helper_->MakeNode("Gather",
34+
{shape_node->output(0),
35+
helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
36+
std::vector<int64_t>{dim_})},
37+
1);
38+
dim_size_name = dim_node->output(0);
39+
40+
std::string repeat_info_name = "";
41+
int64_t repeat = 0;
42+
if (in_pir_mode) {
43+
if (OpType() == "pd_op.repeat_interleave") {
44+
GetAttr("repeats", &repeat);
45+
}
46+
} else {
47+
GetAttr("Repeats", &repeat);
48+
}
49+
50+
if (HasInput("RepeatTensor")) {
51+
auto tmp_info = GetInput("RepeatTensor");
52+
repeat_info_name = helper_->AutoCast(
53+
tmp_info[0].name, tmp_info[0].dtype, P2ODataType::INT64);
54+
} else if (repeat != 0) {
55+
auto repeat_node =
56+
helper_->MakeNode("Expand",
57+
{helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
58+
std::vector<int64_t>{repeat}),
59+
dim_size_name},
60+
1);
61+
repeat_info_name = repeat_node->output(0);
62+
}
63+
64+
std::vector<std::string> split_input_names;
65+
66+
split_input_names.push_back(x_info[0].name);
67+
68+
std::vector<std::string> output_names;
69+
int x_shape_size = x_info[0].shape.size();
70+
71+
std::string prefix_name = helper_->Constant(
72+
ONNX_NAMESPACE::TensorProto::INT64, std::vector<int64_t>(dim_, 1));
73+
std::string suffix_name =
74+
helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
75+
std::vector<int64_t>(x_shape_size - dim_ - 1, 1));
76+
77+
std::string tile_name =
78+
helper_->Concat({prefix_name, repeat_info_name, suffix_name}, 0);
79+
auto node = helper_->MakeNode("Tile", {x_info[0].name, tile_name}, 1);
80+
output_names.push_back(node->output(0));
81+
82+
helper_->Concat(output_names, out_info[0].name, dim_);
83+
}
84+
85+
void RepeatInterleaveMapper::StaticRepeatInterleave(
86+
const std::vector<TensorInfo>& x_info,
87+
const std::vector<TensorInfo>& out_info) {
2988
int n = x_info[0].shape[dim_];
3089
int x_shape_size = x_info[0].shape.size();
3190

@@ -79,4 +138,23 @@ void RepeatInterleaveMapper::Opset9() {
79138
}
80139
helper_->Concat(output_names, out_info[0].name, dim_);
81140
}
141+
142+
void RepeatInterleaveMapper::Opset9() {
143+
auto x_info = GetInput("X");
144+
auto out_info = GetOutput("Out");
145+
146+
bool is_dynamic_shape = false;
147+
for (auto dim : x_info[0].shape) {
148+
if (dim == -1) {
149+
is_dynamic_shape = true;
150+
break;
151+
}
152+
}
153+
154+
if (is_dynamic_shape) {
155+
DynamicRepeatInterleave(x_info, out_info);
156+
} else {
157+
StaticRepeatInterleave(x_info, out_info);
158+
}
159+
}
82160
} // namespace paddle2onnx

paddle2onnx/mapper/tensor/repeat_interleave.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class RepeatInterleaveMapper : public Mapper {
3939

4040
void Opset9() override;
4141
int32_t GetMinOpsetVersion(bool verbose) override;
42+
void DynamicRepeatInterleave(const std::vector<TensorInfo> &x_info,
43+
const std::vector<TensorInfo> &out_info);
44+
void StaticRepeatInterleave(const std::vector<TensorInfo> &x_info,
45+
const std::vector<TensorInfo> &out_info);
4246

4347
private:
4448
int64_t dim_;

0 commit comments

Comments
 (0)