Skip to content

Commit c1fb632

Browse files
authored
【Hackathon 5th No.47】API转换 103-124 (#346)
1 parent b51d791 commit c1fb632

31 files changed

+1912
-243
lines changed

paconvert/api_mapping.json

Lines changed: 316 additions & 16 deletions
Large diffs are not rendered by default.

paconvert/api_matcher.py

Lines changed: 33 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,20 @@ def get_paddle_class_nodes(self, func, args, kwargs):
169169
return "delete"
170170

171171

172+
class AtleastMatcher(BaseMatcher):
173+
def get_paddle_nodes(self, args, kwargs):
174+
new_args = self.parse_args(args)
175+
if new_args[0][0] == "(" and new_args[0][-1] == ")":
176+
new_args[0] = new_args[0][1:-1]
177+
if new_args[0][0] == "[" and new_args[0][-1] == "]":
178+
new_args[0] = new_args[0][1:-1]
179+
new_kwargs = self.parse_kwargs(kwargs)
180+
code = "{}({})".format(
181+
self.get_paddle_api(), self.args_and_kwargs_to_str(new_args, new_kwargs)
182+
)
183+
return ast.parse(code).body
184+
185+
172186
class UnchangeMatcher(BaseMatcher):
173187
def get_paddle_class_attribute_nodes(self, node):
174188
return "unchange"
@@ -1549,7 +1563,11 @@ def generate_code(self, kwargs):
15491563
if len(kwargs) == 0:
15501564
code = f"str({self.paddleClass}.dtype)"
15511565
else:
1552-
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
1566+
# For torch.nn.Module.type, torch.nn.Module.type use torch.Tensor.type
1567+
if "dst_type" in kwargs:
1568+
code = f"{self.paddleClass}.astype({kwargs['dst_type']})"
1569+
else:
1570+
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
15531571
return code
15541572

15551573

@@ -2111,61 +2129,21 @@ def get_paddle_nodes(self, args, kwargs):
21112129

21122130

21132131
class TensorToMatcher(BaseMatcher):
2114-
def generate_aux_code(self):
2115-
CODE_TEMPLATE = textwrap.dedent(
2116-
"""
2117-
def to(self, *args, **kwargs):
2118-
args_list = ["x", "y", "non_blocking", "copy", "memory_format"]
2119-
new_kwargs = {}
2120-
for i, node in enumerate(args):
2121-
k = args_list[i]
2122-
new_kwargs[k] = node
2123-
for node in kwargs:
2124-
v = kwargs[node]
2125-
new_kwargs[node] = v
2126-
kwargs = new_kwargs
2127-
if not kwargs:
2128-
return self
2129-
elif "tensor" in kwargs:
2130-
return paddle.cast(self, "{}.dtype".format(kwargs["tensor"]))
2131-
elif "dtype" in kwargs:
2132-
return paddle.cast(self, "{}".format(kwargs["dtype"]))
2133-
elif "device" in kwargs and "dtype" not in kwargs:
2134-
return self
2135-
elif kwargs:
2136-
if "y" not in kwargs and "x" in kwargs:
2137-
if isinstance(kwargs["x"], paddle.dtype):
2138-
dtype = kwargs["x"]
2139-
elif isinstance(kwargs["x"], str) and kwargs["x"] not in ['cpu', 'cuda', 'ipu', 'xpu']:
2140-
dtype = kwargs["x"]
2141-
elif isinstance(kwargs["x"], paddle.Tensor):
2142-
dtype = kwargs["x"].dtype
2143-
else:
2144-
dtype = self.dtype
2145-
return paddle.cast(self, dtype)
2146-
2147-
elif "y" in kwargs and "x" in kwargs:
2148-
if isinstance(kwargs["x"], paddle.dtype):
2149-
dtype = kwargs["x"]
2150-
elif isinstance(kwargs["x"], str):
2151-
if x not in ['cpu', 'cuda', 'ipu', 'xpu']:
2152-
dtype = kwargs["x"]
2153-
else:
2154-
dtype = kwargs["y"] if isinstance(kwargs["y"], str) else self.dtype
2155-
else:
2156-
dtype = kwargs["x"]
2157-
return paddle.cast(self, dtype)
2158-
else:
2159-
return self
2160-
2161-
setattr(paddle.Tensor, 'to', to)
2162-
"""
2132+
def get_paddle_nodes(self, args, kwargs):
2133+
new_args = self.parse_args(args)
2134+
new_kwargs = self.parse_kwargs(kwargs)
2135+
if new_kwargs is None:
2136+
return None
2137+
if "copy" in new_kwargs:
2138+
new_kwargs.pop("copy")
2139+
if "memory_format" in new_kwargs:
2140+
new_kwargs.pop("memory_format")
2141+
if "non_blocking" in new_kwargs:
2142+
new_kwargs["blocking"] = "not " + new_kwargs.pop("non_blocking").strip("()")
2143+
code = "{}.to({})".format(
2144+
self.paddleClass, self.args_and_kwargs_to_str(new_args, new_kwargs)
21632145
)
2164-
return CODE_TEMPLATE
2165-
2166-
def get_paddle_class_nodes(self, func, args, kwargs):
2167-
self.write_aux_code()
2168-
return "unchange"
2146+
return ast.parse(code).body
21692147

21702148

21712149
class TensorRequiresGradMatcher(BaseMatcher):
@@ -2964,24 +2942,6 @@ def get_paddle_nodes(self, args, kwargs):
29642942
return ast.parse(code).body
29652943

29662944

2967-
class HypotMatcher(BaseMatcher):
2968-
def generate_code(self, kwargs):
2969-
if "input" not in kwargs:
2970-
kwargs["input"] = self.paddleClass
2971-
2972-
API_TEMPLATE = textwrap.dedent(
2973-
"""
2974-
paddle.pow({}**2 + {}**2, 1/2)
2975-
"""
2976-
)
2977-
code = API_TEMPLATE.format(kwargs["input"], kwargs["other"])
2978-
2979-
if "out" in kwargs and kwargs["out"] != "None":
2980-
code = "paddle.assign({}, output={})".format(code, kwargs["out"])
2981-
2982-
return code
2983-
2984-
29852945
class TensorViewMatcher(BaseMatcher):
29862946
def generate_aux_code(self):
29872947
CODE_TEMPLATE = textwrap.dedent(

tests/test_Tensor_diagonal_scatter.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,48 +23,57 @@ def test_case_1():
2323
pytorch_code = textwrap.dedent(
2424
"""
2525
import torch
26-
input = torch.zeros(3, 3)
27-
src = torch.ones(3)
26+
input = torch.arange(6.0).reshape((2, 3))
27+
src = torch.ones((2,))
2828
result = input.diagonal_scatter(src)
2929
"""
3030
)
31-
obj.run(
32-
pytorch_code,
33-
["result"],
34-
unsupport=True,
35-
reason="paddle does not support this function temporarily",
36-
)
31+
obj.run(pytorch_code, ["result"])
3732

3833

3934
def test_case_2():
4035
pytorch_code = textwrap.dedent(
4136
"""
4237
import torch
43-
input = torch.zeros(3, 3)
44-
src = torch.ones(3)
38+
input = torch.arange(6.0).reshape((2, 3))
39+
src = torch.ones((2,))
4540
result = input.diagonal_scatter(src=src)
4641
"""
4742
)
48-
obj.run(
49-
pytorch_code,
50-
["result"],
51-
unsupport=True,
52-
reason="paddle does not support this function temporarily",
53-
)
43+
obj.run(pytorch_code, ["result"])
5444

5545

5646
def test_case_3():
5747
pytorch_code = textwrap.dedent(
5848
"""
5949
import torch
60-
input = torch.zeros(3, 3)
61-
src = torch.ones(3)
62-
result = input.diagonal_scatter(src=src, offset=0, dim1=-2)
50+
input = torch.arange(6.0).reshape((2, 3))
51+
src = torch.ones((2,))
52+
result = input.diagonal_scatter(offset=0, src=src, dim2=1, dim1=-2)
6353
"""
6454
)
65-
obj.run(
66-
pytorch_code,
67-
["result"],
68-
unsupport=True,
69-
reason="paddle does not support this function temporarily",
55+
obj.run(pytorch_code, ["result"])
56+
57+
58+
def test_case_4():
59+
pytorch_code = textwrap.dedent(
60+
"""
61+
import torch
62+
input = torch.arange(6.0).reshape((2, 3))
63+
src = torch.ones((2,))
64+
result = input.diagonal_scatter(src=src, offset=0, dim1=-2, dim2=1)
65+
"""
66+
)
67+
obj.run(pytorch_code, ["result"])
68+
69+
70+
def test_case_5():
71+
pytorch_code = textwrap.dedent(
72+
"""
73+
import torch
74+
input = torch.arange(6.0).reshape((2, 3))
75+
src = torch.ones((2,))
76+
result = input.diagonal_scatter(src, 0, -2, 1)
77+
"""
7078
)
79+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_hypot.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,15 @@ def test_case_4():
6464
"""
6565
)
6666
obj.run(pytorch_code, ["result"])
67+
68+
69+
def test_case_5():
70+
pytorch_code = textwrap.dedent(
71+
"""
72+
import torch
73+
a = torch.tensor([1., 2, 3])
74+
b = torch.tensor([4., 5, 6])
75+
result = a.hypot(b+1)
76+
"""
77+
)
78+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_hypot_.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import textwrap
15+
16+
from apibase import APIBase
17+
18+
obj = APIBase("torch.Tensor.hypot_")
19+
20+
21+
def test_case_1():
22+
pytorch_code = textwrap.dedent(
23+
"""
24+
import torch
25+
a = torch.tensor([1., 2, 3])
26+
b = torch.tensor([4., 5, 6])
27+
result = a.hypot_(b)
28+
"""
29+
)
30+
obj.run(pytorch_code, ["result", "a"])
31+
32+
33+
def test_case_2():
34+
pytorch_code = textwrap.dedent(
35+
"""
36+
import torch
37+
a = torch.tensor([1., 2, 3])
38+
b = torch.tensor([4., 5, 6])
39+
result = a.hypot_(other=b)
40+
"""
41+
)
42+
obj.run(pytorch_code, ["result", "a"])
43+
44+
45+
def test_case_3():
46+
pytorch_code = textwrap.dedent(
47+
"""
48+
import torch
49+
a = torch.tensor([-1., 2, 3])
50+
b = torch.tensor([4., 5, 6])
51+
result = a.hypot_(other=b)
52+
"""
53+
)
54+
obj.run(pytorch_code, ["result", "a"])
55+
56+
57+
def test_case_4():
58+
pytorch_code = textwrap.dedent(
59+
"""
60+
import torch
61+
a = torch.tensor([1., 2, 3])
62+
b = torch.tensor([4., 5, 6])
63+
result = a.hypot_(other=b+1)
64+
"""
65+
)
66+
obj.run(pytorch_code, ["result", "a"])
67+
68+
69+
def test_case_5():
70+
pytorch_code = textwrap.dedent(
71+
"""
72+
import torch
73+
a = torch.tensor([1., 2, 3])
74+
b = torch.tensor([4., 5, 6])
75+
result = a.hypot_(b+1)
76+
"""
77+
)
78+
obj.run(pytorch_code, ["result", "a"])

tests/test_Tensor_index_fill.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.Tensor.index_fill")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
x = torch.eye(2, 4)
27+
indices = torch.tensor([0, 1])
28+
value = -1
29+
result = x.index_fill(0, indices, value)
30+
"""
31+
)
32+
obj.run(pytorch_code, ["result"])
33+
34+
35+
def test_case_2():
36+
pytorch_code = textwrap.dedent(
37+
"""
38+
import torch
39+
indices = torch.tensor([0, 1])
40+
value = -1
41+
result = torch.eye(3, 4).index_fill(1, indices, value)
42+
"""
43+
)
44+
obj.run(pytorch_code, ["result"])
45+
46+
47+
def test_case_3():
48+
pytorch_code = textwrap.dedent(
49+
"""
50+
import torch
51+
indices = torch.tensor([0, 1])
52+
dim = 0
53+
value = -1
54+
result = torch.eye(3, 4).index_fill(index=indices, dim=dim, value=value)
55+
"""
56+
)
57+
obj.run(pytorch_code, ["result"])
58+
59+
60+
def test_case_4():
61+
pytorch_code = textwrap.dedent(
62+
"""
63+
import torch
64+
indices = torch.tensor([0, 3])
65+
dim = 0
66+
value = -1
67+
result = torch.eye(6, 4).index_fill(dim=dim, index=indices, value=value)
68+
"""
69+
)
70+
obj.run(pytorch_code, ["result"])
71+
72+
73+
def test_case_5():
74+
pytorch_code = textwrap.dedent(
75+
"""
76+
import torch
77+
indices = torch.tensor([0, 3])
78+
value = -1
79+
result = torch.eye(3, 4).index_fill(1, indices, value)
80+
"""
81+
)
82+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)