Skip to content

Commit a0b7c04

Browse files
update roll (#1228)
1 parent 83ecf73 commit a0b7c04

File tree

5 files changed

+161
-3
lines changed

5 files changed

+161
-3
lines changed

.gitignore

+12-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
*.pyc
66
.pydevproject
7-
build/*
87
.eggs/*
98
dist/*
109
.setuptools*
@@ -13,4 +12,15 @@ paddle2onnx.egg-info/*
1312
*_*.onnx
1413
*.log
1514
version.py
16-
paddle2onnx/mappers_registry.h
15+
paddle2onnx/mappers_registry.h
16+
17+
# CMD
18+
build/*
19+
paddle2onnx-*
20+
21+
# Clion
22+
cmake-build-*
23+
.idea
24+
25+
# VSCode
26+
.vscode

VERSION_NUMBER

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.2.0
1+
1.2.1

paddle2onnx/mapper/tensor/roll.cc

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <limits>
16+
#include "paddle2onnx/mapper/tensor/roll.h"
17+
18+
namespace paddle2onnx {
19+
REGISTER_MAPPER(roll, RollMapper)
20+
21+
void RollMapper::Opset7() {
22+
auto input_info = GetInput("X");
23+
auto output_info = GetOutput("Out");
24+
25+
std::vector<int64_t> shifts;
26+
GetAttr("shifts", &shifts);
27+
28+
std::vector<int64_t> axis;
29+
GetAttr("axis", &axis);
30+
31+
std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
32+
auto result_name = input_info[0].name;
33+
if (axis.empty())
34+
{
35+
int64_t axes = 0;
36+
result_name = helper_->Flatten(result_name);
37+
for(int i = 0;i < shifts.size();i++) {
38+
auto shift = shifts[i];
39+
auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
40+
auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
41+
temp_node = helper_->MakeNode("Concat", {result_0, result_1});
42+
AddAttribute(temp_node, "axis", axes);
43+
result_name = temp_node->output(0);
44+
}
45+
helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
46+
// helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name});
47+
} else {
48+
for(int i = 0;i < shifts.size();i++) {
49+
auto shift = shifts[i];
50+
int64_t axes = axis[i];
51+
auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
52+
auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
53+
if(i+1 == shifts.size()) {
54+
temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name});
55+
} else {
56+
temp_node = helper_->MakeNode("Concat", {result_0, result_1});
57+
}
58+
AddAttribute(temp_node, "axis", axes);
59+
result_name = temp_node->output(0);
60+
}
61+
}
62+
}
63+
} // namespace paddle2onnx

paddle2onnx/mapper/tensor/roll.h

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include <string>
17+
#include <vector>
18+
19+
#include "paddle2onnx/mapper/mapper.h"
20+
21+
namespace paddle2onnx {
22+
23+
class RollMapper : public Mapper {
24+
public:
25+
RollMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
26+
int64_t op_id)
27+
: Mapper(p, helper, block_id, op_id) {}
28+
void Opset7();
29+
};
30+
31+
} // namespace paddle2onnx

tests/test_roll.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
from onnxbase import APIOnnx
17+
from onnxbase import randtool
18+
19+
20+
class Net(paddle.nn.Layer):
21+
"""
22+
simple Net
23+
"""
24+
25+
def __init__(self):
26+
super(Net, self).__init__()
27+
28+
def forward(self, inputs):
29+
"""
30+
forward
31+
"""
32+
x = paddle.roll(inputs, 1)
33+
return x
34+
35+
36+
def test_roll():
37+
"""
38+
api: paddle.roll
39+
op version: 9
40+
"""
41+
op = Net()
42+
op.eval()
43+
# net, name, ver_list, delta=1e-6, rtol=1e-5
44+
obj = APIOnnx(op, 'roll', [9])
45+
input_data = paddle.to_tensor(randtool("float", -1, 1, [2,2]).astype('float32'))
46+
print(input_data)
47+
obj.set_input_data(
48+
"input_data",
49+
input_data
50+
)
51+
obj.run()
52+
53+
if __name__ == "__main__":
54+
test_roll()

0 commit comments

Comments
 (0)