Skip to content

Commit fd0205c

Browse files
authored
PyTorch converter for Seg SwinTransformer (#637)
* Update stargan.md * fix the paddle_type * add docs * add docs * add acknowledge * fix the docs * fix the docs * add docs * fix * add docs * add docs * Update README.md * fix onnx inputs * fix * fix * remove * remove numpy input * add for seg swin transformer * add pad * add pad * fix onnx
1 parent 32eaafe commit fd0205c

File tree

6 files changed

+141
-25
lines changed

6 files changed

+141
-25
lines changed

x2paddle/core/program.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def add_layer(self, kernel, inputs, outputs, scope_name="", **kwargs):
109109
layer = PaddleLayer(
110110
layer_id, kernel, inputs, outputs, scope_name=scope_name, **kwargs)
111111
self.layers[layer_id] = layer
112-
if layer.kernel in ["prim.list_unpack" or "prim.tuple_unpack"]:
112+
if layer.kernel in ["prim.list_unpack" , "prim.tuple_unpack"]:
113113
self.has_unpack = True
114114
return layer_id
115115

x2paddle/op_mapper/onnx2paddle/opset9/opset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1514,7 +1514,7 @@ def PRelu(self, node):
15141514
"paddle.minimum",
15151515
inputs={"x": val_x.name,
15161516
"y": output_name + "__zeros"},
1517-
outputs=[output_name + "__max"])
1517+
outputs=[output_name + "__min"])
15181518
self.paddle_graph.add_layer(
15191519
"paddle.multiply",
15201520
inputs={"x": val_slope.name,

x2paddle/op_mapper/pytorch2paddle/aten.py

+75-23
Original file line numberDiff line numberDiff line change
@@ -983,23 +983,31 @@ def aten_constant_pad_nd(mapper, graph, node):
983983
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
984984
scope_name)
985985
layer_inputs["input"] = inputs_name[0]
986+
# 处理输入1,即%4876
987+
is_padding_tensor = False
988+
if inputs_name[1] in mapper.attrs:
989+
layer_attrs["padding"] = mapper.attrs[inputs_name[1]]
990+
else:
991+
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
992+
scope_name)
993+
layer_inputs["pad"] = inputs_name[1]
994+
is_padding_tensor = True
986995
# 获取当前节点输入的list
987996
current_inputs = list(layer_inputs.values())
988-
# 处理输入1,即%4876
989-
layer_attrs["padding"] = mapper.attrs[inputs_name[1]]
990997
# 处理输入2,即%42
991998
layer_attrs["value"] = mapper.attrs[inputs_name[2]]
992999

993-
graph.add_layer(
994-
"prim.shape",
995-
inputs={"input": inputs_name[0]},
996-
outputs=[inputs_name[0] + "_shape"],
997-
scope_name=scope_name)
998-
graph.add_layer(
999-
"prim.len",
1000-
inputs={"input": inputs_name[0] + "_shape"},
1001-
outputs=[inputs_name[0] + "_len"],
1002-
scope_name=scope_name)
1000+
if not is_padding_tensor:
1001+
graph.add_layer(
1002+
"prim.shape",
1003+
inputs={"input": inputs_name[0]},
1004+
outputs=[inputs_name[0] + "_shape"],
1005+
scope_name=scope_name)
1006+
graph.add_layer(
1007+
"prim.len",
1008+
inputs={"input": inputs_name[0] + "_shape"},
1009+
outputs=[inputs_name[0] + "_len"],
1010+
scope_name=scope_name)
10031011

10041012
def add_pad_layers(kernel, dim):
10051013
graph.add_layer(
@@ -1020,6 +1028,7 @@ def add_pad_layers(kernel, dim):
10201028
inputs={"y": inputs_name[0] + "_len"},
10211029
outputs=[inputs_name[0] + "_len0"],
10221030
scope_name=scope_name,
1031+
alpha=1.0,
10231032
x=dim)
10241033
block.add_layer(
10251034
"prim.len2list",
@@ -1058,17 +1067,25 @@ def add_pad_layers(kernel, dim):
10581067
if_layer.inputs["input-0"] = inputs_name[0]
10591068
if_layer.inputs["input-1"] = inputs_name[0] + "_len"
10601069

1061-
if len(layer_attrs["padding"]) == 2:
1062-
layer_outputs[0] = layer_outputs[0].raplace("pad", "pad1d")
1063-
add_pad_layers("paddle.nn.Pad1D", 3)
1064-
elif len(layer_attrs["padding"]) == 4:
1065-
layer_outputs[0] = layer_outputs[0].raplace("pad", "pad2d")
1066-
add_pad_layers("paddle.nn.Pad2D", 4)
1067-
elif len(layer_attrs["padding"]) == 6:
1068-
layer_outputs[0] = layer_outputs[0].raplace("pad", "pad3d")
1069-
add_pad_layers("paddle.nn.Pad3D", 5)
1070+
if not is_padding_tensor:
1071+
if len(layer_attrs["padding"]) == 2:
1072+
layer_outputs[0] = layer_outputs[0].replace("pad", "pad1d")
1073+
add_pad_layers("paddle.nn.Pad1D", 3)
1074+
elif len(layer_attrs["padding"]) == 4:
1075+
layer_outputs[0] = layer_outputs[0].replace("pad", "pad2d")
1076+
add_pad_layers("paddle.nn.Pad2D", 4)
1077+
elif len(layer_attrs["padding"]) == 6:
1078+
layer_outputs[0] = layer_outputs[0].replace("pad", "pad3d")
1079+
add_pad_layers("paddle.nn.Pad3D", 5)
1080+
else:
1081+
raise Exception("The lenght of padding list must be 2, 4 or 6!")
10701082
else:
1071-
raise Exception("The lenght of padding list must be 2, 4 or 6!")
1083+
graph.add_layer(
1084+
"custom_layer:Pad",
1085+
inputs=layer_inputs,
1086+
outputs=[output_name],
1087+
scope_name=scope_name,
1088+
**layer_attrs)
10721089
return current_inputs, current_outputs
10731090

10741091

@@ -4191,10 +4208,45 @@ def aten_relu6(mapper, graph, node):
41914208
return current_inputs, current_outputs
41924209

41934210

4211+
def aten_remainder(mapper, graph, node):
4212+
""" 构造取余数的PaddleLayer。
4213+
TorchScript示例:
4214+
%701 : Tensor = aten::remainder(%661, %139)
4215+
参数含义:
4216+
%701 (Tensor): 输出,取余结果的Tensor。
4217+
%661 (Tensor): 需要取余的Tensor。
4218+
%139 (Tensor): 除数Tensor。
4219+
"""
4220+
scope_name = mapper.normalize_scope_name(node)
4221+
output_name = mapper._get_outputs_name(node)[0]
4222+
layer_outputs = [output_name]
4223+
layer_inputs = {}
4224+
inputs_name, inputs_node = mapper._get_inputs_name(node)
4225+
# 获取当前节点输出的list
4226+
current_outputs = [output_name]
4227+
# 处理输入0,即%661
4228+
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
4229+
scope_name)
4230+
layer_inputs["x"] = inputs_name[0]
4231+
# 处理输入1,即%139
4232+
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
4233+
scope_name)
4234+
layer_inputs["y"] = inputs_name[1]
4235+
# 获取当前节点输入、输出的list
4236+
current_inputs = list(layer_inputs.values())
4237+
4238+
graph.add_layer(
4239+
"prim.remainder",
4240+
inputs=layer_inputs,
4241+
outputs=layer_outputs,
4242+
scope_name=scope_name)
4243+
return current_inputs, current_outputs
4244+
4245+
41944246
def aten_repeat(mapper, graph, node):
41954247
""" 构造根据参数对输入各维度进行复制的PaddleLayer。
41964248
TorchScript示例:
4197-
701 : Tensor = aten::repeat(%699, %700)
4249+
%701 : Tensor = aten::repeat(%699, %700)
41984250
参数含义:
41994251
%701 (Tensor): 输出,复制后的Tensor。
42004252
%699 (Tensor): 需要复制的Tensor。

x2paddle/op_mapper/pytorch2paddle/prim2code.py

+15
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,21 @@ def prim_or(layer,
609609
if is_return_line:
610610
return line.split(" = ")[1]
611611
forward_func.extend(gen_codes([line], indent=indent))
612+
613+
614+
def prim_remainder(layer,
615+
indent=1,
616+
init_func=[],
617+
forward_func=[],
618+
layer_id=None,
619+
different_attrs=None,
620+
is_return_line=False):
621+
line = "{} = {} % {}".format(layer.outputs[0],
622+
get_value(layer, "x", different_attrs),
623+
get_value(layer, "y", different_attrs))
624+
if is_return_line:
625+
return line.split(" = ")[1]
626+
forward_func.extend(gen_codes([line], indent=indent))
612627

613628

614629
def prim_replaceitem(layer,

x2paddle/op_mapper/pytorch2paddle/pytorch_custom_layer/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414

1515
from .gather import Gather
1616
from .instance_norm import InstanceNorm
17+
from .pad import Pad
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
from x2paddle.core.util import *
17+
18+
19+
class Pad(object):
20+
def __init__(self, value):
21+
self.value = value
22+
23+
def __call__(self, input, pad):
24+
shape = input.shape
25+
dim = len(shape)
26+
if len(pad) == 2:
27+
data_format = "NCL"
28+
elif len(pad) == 4:
29+
data_format = "NCHW"
30+
elif len(pad) == 6:
31+
data_format = "NCDHW"
32+
if dim == 3 and len(pad) == 4:
33+
input = paddle.unsqueeze(input, [0])
34+
output = paddle.nn.functional.pad(input,
35+
pad,
36+
data_format=data_format)
37+
output = paddle.squeeze(output, [0])
38+
elif dim == 4 and len(pad) == 6:
39+
input = paddle.unsqueeze(input, [0])
40+
output = paddle.nn.functional.pad(input,
41+
pad,
42+
data_format=data_format)
43+
output = paddle.squeeze(output, [0])
44+
else:
45+
output = paddle.nn.functional.pad(input,
46+
pad,
47+
data_format=data_format)
48+
return output

0 commit comments

Comments
 (0)