@@ -23,9 +23,68 @@ int32_t RepeatInterleaveMapper::GetMinOpsetVersion(bool verbose) {
23
23
return op_version;
24
24
}
25
25
26
- void RepeatInterleaveMapper::Opset9 () {
27
- auto x_info = GetInput (" X" ); // shape = [1, 2, 3]
28
- auto out_info = GetOutput (" Out" );
26
+ void RepeatInterleaveMapper::DynamicRepeatInterleave (
27
+ const std::vector<TensorInfo>& x_info,
28
+ const std::vector<TensorInfo>& out_info) {
29
+ std::string dim_size_name = " " ;
30
+
31
+ auto shape_node = helper_->MakeNode (" Shape" , {x_info[0 ].name }, 1 );
32
+ auto dim_node =
33
+ helper_->MakeNode (" Gather" ,
34
+ {shape_node->output (0 ),
35
+ helper_->Constant (ONNX_NAMESPACE::TensorProto::INT64,
36
+ std::vector<int64_t >{dim_})},
37
+ 1 );
38
+ dim_size_name = dim_node->output (0 );
39
+
40
+ std::string repeat_info_name = " " ;
41
+ int64_t repeat = 0 ;
42
+ if (in_pir_mode) {
43
+ if (OpType () == " pd_op.repeat_interleave" ) {
44
+ GetAttr (" repeats" , &repeat);
45
+ }
46
+ } else {
47
+ GetAttr (" Repeats" , &repeat);
48
+ }
49
+
50
+ if (HasInput (" RepeatTensor" )) {
51
+ auto tmp_info = GetInput (" RepeatTensor" );
52
+ repeat_info_name = helper_->AutoCast (
53
+ tmp_info[0 ].name , tmp_info[0 ].dtype , P2ODataType::INT64);
54
+ } else if (repeat != 0 ) {
55
+ auto repeat_node =
56
+ helper_->MakeNode (" Expand" ,
57
+ {helper_->Constant (ONNX_NAMESPACE::TensorProto::INT64,
58
+ std::vector<int64_t >{repeat}),
59
+ dim_size_name},
60
+ 1 );
61
+ repeat_info_name = repeat_node->output (0 );
62
+ }
63
+
64
+ std::vector<std::string> split_input_names;
65
+
66
+ split_input_names.push_back (x_info[0 ].name );
67
+
68
+ std::vector<std::string> output_names;
69
+ int x_shape_size = x_info[0 ].shape .size ();
70
+
71
+ std::string prefix_name = helper_->Constant (
72
+ ONNX_NAMESPACE::TensorProto::INT64, std::vector<int64_t >(dim_, 1 ));
73
+ std::string suffix_name =
74
+ helper_->Constant (ONNX_NAMESPACE::TensorProto::INT64,
75
+ std::vector<int64_t >(x_shape_size - dim_ - 1 , 1 ));
76
+
77
+ std::string tile_name =
78
+ helper_->Concat ({prefix_name, repeat_info_name, suffix_name}, 0 );
79
+ auto node = helper_->MakeNode (" Tile" , {x_info[0 ].name , tile_name}, 1 );
80
+ output_names.push_back (node->output (0 ));
81
+
82
+ helper_->Concat (output_names, out_info[0 ].name , dim_);
83
+ }
84
+
85
+ void RepeatInterleaveMapper::StaticRepeatInterleave (
86
+ const std::vector<TensorInfo>& x_info,
87
+ const std::vector<TensorInfo>& out_info) {
29
88
int n = x_info[0 ].shape [dim_];
30
89
int x_shape_size = x_info[0 ].shape .size ();
31
90
@@ -79,4 +138,23 @@ void RepeatInterleaveMapper::Opset9() {
79
138
}
80
139
helper_->Concat (output_names, out_info[0 ].name , dim_);
81
140
}
141
+
142
+ void RepeatInterleaveMapper::Opset9 () {
143
+ auto x_info = GetInput (" X" );
144
+ auto out_info = GetOutput (" Out" );
145
+
146
+ bool is_dynamic_shape = false ;
147
+ for (auto dim : x_info[0 ].shape ) {
148
+ if (dim == -1 ) {
149
+ is_dynamic_shape = true ;
150
+ break ;
151
+ }
152
+ }
153
+
154
+ if (is_dynamic_shape) {
155
+ DynamicRepeatInterleave (x_info, out_info);
156
+ } else {
157
+ StaticRepeatInterleave (x_info, out_info);
158
+ }
159
+ }
82
160
} // namespace paddle2onnx
0 commit comments