Skip to content

Commit 5c1c77b

Browse files
authored
Merge branch 'PaddlePaddle:master' into master
2 parents 7103fd9 + 8fa1045 commit 5c1c77b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+3319
-10
lines changed

paconvert/api_mapping.json

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5437,7 +5437,7 @@
54375437
},
54385438
"torch.Tensor.addmm_": {},
54395439
"torch.Tensor.allclose": {
5440-
"Matcher": "GenericMatcher",
5440+
"Matcher": "AllcloseMatcher",
54415441
"paddle_api": "paddle.Tensor.allclose",
54425442
"args_list": [
54435443
"other",
@@ -5534,7 +5534,7 @@
55345534
},
55355535
"torch.Tensor.atan2_": {},
55365536
"torch.Tensor.all": {
5537-
"Matcher": "GenericMatcher",
5537+
"Matcher": "TensorToBoolMatcher",
55385538
"paddle_api": "paddle.Tensor.all",
55395539
"args_list": [
55405540
"dim",
@@ -5545,7 +5545,7 @@
55455545
}
55465546
},
55475547
"torch.Tensor.any": {
5548-
"Matcher": "GenericMatcher",
5548+
"Matcher": "TensorToBoolMatcher",
55495549
"paddle_api": "paddle.Tensor.any",
55505550
"args_list": [
55515551
"dim",
@@ -5579,7 +5579,7 @@
55795579
},
55805580
"torch.Tensor.bitwise_or": {
55815581
"Matcher": "GenericMatcher",
5582-
"paddle_api": "paddle.Tensor.bitwise_and",
5582+
"paddle_api": "paddle.Tensor.bitwise_or",
55835583
"args_list": [
55845584
"other"
55855585
],
@@ -5589,7 +5589,7 @@
55895589
},
55905590
"torch.Tensor.bitwise_xor": {
55915591
"Matcher": "GenericMatcher",
5592-
"paddle_api": "paddle.Tensor.bitwise_and",
5592+
"paddle_api": "paddle.Tensor.bitwise_xor",
55935593
"args_list": [
55945594
"other"
55955595
],
@@ -5601,18 +5601,21 @@
56015601
"Matcher": "GenericMatcher",
56025602
"paddle_api": "paddle.Tensor.bmm",
56035603
"args_list": [
5604-
"batch2"
5604+
"mat2"
56055605
],
56065606
"kwargs_change": {
5607-
"batch2": "y"
5607+
"mat2": "y"
56085608
}
56095609
},
56105610
"torch.Tensor.broadcast_to": {
56115611
"Matcher": "GenericMatcher",
56125612
"paddle_api": "paddle.Tensor.broadcast_to",
56135613
"args_list": [
5614-
"shape"
5615-
]
5614+
"size"
5615+
],
5616+
"kwargs_change": {
5617+
"size": "shape"
5618+
}
56165619
},
56175620
"torch.Tensor.ceil": {
56185621
"Matcher": "TensorUnchangeMatcher"
@@ -5690,7 +5693,7 @@
56905693
},
56915694
"torch.Tensor.arccosh": {
56925695
"Matcher": "GenericMatcher",
5693-
"paddle_api": "paddle.Tensor.arccosh"
5696+
"paddle_api": "paddle.Tensor.acosh"
56945697
},
56955698
"torch.Tensor.cross": {
56965699
"Matcher": "GenericMatcher",
@@ -7977,6 +7980,30 @@
79777980
"dim": "axes"
79787981
}
79797982
},
7983+
"torch.fft.fftshift": {
7984+
"Matcher": "GenericMatcher",
7985+
"paddle_api": "paddle.fft.fftshift",
7986+
"args_list": [
7987+
"input",
7988+
"dim"
7989+
],
7990+
"kwargs_change": {
7991+
"input": "x",
7992+
"dim": "axes"
7993+
}
7994+
},
7995+
"torch.fft.ifftshift": {
7996+
"Matcher": "GenericMatcher",
7997+
"paddle_api": "paddle.fft.ifftshift",
7998+
"args_list": [
7999+
"input",
8000+
"dim"
8001+
],
8002+
"kwargs_change": {
8003+
"input": "x",
8004+
"dim": "axes"
8005+
}
8006+
},
79808007
"torch.fft.irfftn": {
79818008
"Matcher": "GenericMatcher",
79828009
"paddle_api": "paddle.fft.irfftn",
@@ -8586,6 +8613,35 @@
85868613
"data_source"
85878614
]
85888615
},
8616+
"torch.utils.data.random_split": {
8617+
"Matcher": "RandomSplitMatcher",
8618+
"paddle_api": "paddle.io.random_split",
8619+
"args_list": [
8620+
"dataset",
8621+
"lengths",
8622+
"generator"
8623+
]
8624+
},
8625+
"torch.utils.dlpack.from_dlpack": {
8626+
"Matcher": "GenericMatcher",
8627+
"paddle_api": "paddle.utils.dlpack.from_dlpack",
8628+
"args_list": [
8629+
"ext_tensor"
8630+
],
8631+
"kwargs_change": {
8632+
"ext_tensor": "dlpack"
8633+
}
8634+
},
8635+
"torch.utils.dlpack.to_dlpack": {
8636+
"Matcher": "GenericMatcher",
8637+
"paddle_api": "paddle.utils.dlpack.to_dlpack",
8638+
"args_list": [
8639+
"tensor"
8640+
],
8641+
"kwargs_change": {
8642+
"tensor": "x"
8643+
}
8644+
},
85898645
"torch.nn.functional.l1_loss": {
85908646
"Matcher": "SizeAverageMatcher",
85918647
"paddle_api": "paddle.nn.functional.l1_loss",

paconvert/api_matcher.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3645,3 +3645,38 @@ class SizeAverageMatcher(BaseMatcher):
36453645
def generate_code(self, kwargs):
36463646
process_reduce_and_size_average(kwargs)
36473647
return GenericMatcher.generate_code(self, kwargs)
3648+
3649+
3650+
class RandomSplitMatcher(BaseMatcher):
3651+
def generate_code(self, kwargs):
3652+
API_TEMPLATE = textwrap.dedent(
3653+
"""
3654+
dataset_lengths = {}
3655+
if sum(dataset_lengths) <= 1:
3656+
dataset_lengths = [int(length * {}.__len__()) for length in dataset_lengths]
3657+
{}({})
3658+
"""
3659+
)
3660+
lenghts_v = kwargs["lengths"].strip("\n")
3661+
kwargs["lengths"] = "dataset_lengths"
3662+
code = API_TEMPLATE.format(
3663+
lenghts_v,
3664+
kwargs["dataset"],
3665+
self.get_paddle_api(),
3666+
self.kwargs_to_str(kwargs),
3667+
)
3668+
return code.strip("\n")
3669+
3670+
3671+
class TensorToBoolMatcher(BaseMatcher):
3672+
def generate_code(self, kwargs):
3673+
if "dim" in kwargs:
3674+
kwargs["axis"] = kwargs.pop("dim").strip("\n")
3675+
3676+
paddle_api = self.get_paddle_api()
3677+
paddle_api_name = paddle_api[paddle_api.rfind(".") :]
3678+
code = "{}({})".format(
3679+
self.paddleClass + ".astype('bool')" + paddle_api_name,
3680+
self.kwargs_to_str(kwargs),
3681+
)
3682+
return code

tests/test_Tensor_absolute.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import textwrap
17+
18+
from apibase import APIBase
19+
20+
obj = APIBase("torch.absolute")
21+
22+
23+
def test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
import torch
27+
a = torch.tensor([[-4, 9], [-23, 2]])
28+
result = a.absolute()
29+
"""
30+
)
31+
obj.run(pytorch_code, ["result"])
32+
33+
34+
def test_case_2():
35+
pytorch_code = textwrap.dedent(
36+
"""
37+
import torch
38+
result = torch.tensor([[-4, 9], [-23, 2]]).absolute()
39+
"""
40+
)
41+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_acos.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import textwrap
17+
18+
from apibase import APIBase
19+
20+
obj = APIBase("torch.acos")
21+
22+
23+
def test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
import torch
27+
a = torch.tensor([[ 0.3348, -0.5889, 0.2005, -0.1584], [ 0.3348, -0.5889, 0.2005, -0.1584]])
28+
result = a.acos()
29+
"""
30+
)
31+
obj.run(pytorch_code, ["result"])
32+
33+
34+
def test_case_2():
35+
pytorch_code = textwrap.dedent(
36+
"""
37+
import torch
38+
result = torch.tensor([[ 0.3348, -0.5889, 0.2005, -0.1584]]).acos()
39+
"""
40+
)
41+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_acosh.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
17+
import textwrap
18+
19+
from apibase import APIBase
20+
21+
obj = APIBase("torch.Tensor.acosh")
22+
23+
24+
def test_case_1():
25+
pytorch_code = textwrap.dedent(
26+
"""
27+
import torch
28+
result = torch.tensor([1.3192, 1.9915, 1.9674, 1.7151]).acosh()
29+
"""
30+
)
31+
obj.run(pytorch_code, ["result"])
32+
33+
34+
def test_case_2():
35+
pytorch_code = textwrap.dedent(
36+
"""
37+
import torch
38+
a = torch.tensor([1.3192, 1.9915, 1.9674, 1.7151])
39+
result = a.acosh()
40+
"""
41+
)
42+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_addmm.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.Tensor.addmm")
20+
21+
22+
# The paddle input does not support integer type
23+
def _test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
import torch
27+
x = torch.tensor([[1, 2], [4, 5]])
28+
mat1 = torch.tensor([[1, 2], [4, 5]])
29+
mat2 = torch.tensor([[1, 2], [4, 5]])
30+
result = x.addmm(mat1, mat2)
31+
"""
32+
)
33+
obj.run(pytorch_code, ["result"])
34+
35+
36+
def test_case_2():
37+
pytorch_code = textwrap.dedent(
38+
"""
39+
import torch
40+
x = torch.tensor([[1., 2], [4, 5]])
41+
mat1 = torch.tensor([[1., 2], [4, 5]])
42+
mat2 = torch.tensor([[1., 2], [4, 5]])
43+
result = x.addmm(mat1, mat2, beta=0.6, alpha=0.7)
44+
"""
45+
)
46+
obj.run(pytorch_code, ["result"])
47+
48+
49+
def test_case_3():
50+
pytorch_code = textwrap.dedent(
51+
"""
52+
import torch
53+
x = torch.tensor([[1., 2], [4, 5]])
54+
mat1 = torch.tensor([[1., 2], [4, 5]])
55+
mat2 = torch.tensor([[1., 2], [4, 5]])
56+
result = x.addmm(mat1=mat1, mat2=mat2, beta=0.6, alpha=0.7)
57+
"""
58+
)
59+
obj.run(pytorch_code, ["result"])
60+
61+
62+
def test_case_4():
63+
pytorch_code = textwrap.dedent(
64+
"""
65+
import torch
66+
x = torch.tensor([[1., 2], [4, 5]])
67+
mat1 = torch.tensor([[1., 2], [4, 5]])
68+
result = x.addmm(mat1, torch.tensor([[1., 2], [4, 5]]), beta=0.6, alpha=0.7)
69+
"""
70+
)
71+
obj.run(pytorch_code, ["result"])
72+
73+
74+
def test_case_5():
75+
pytorch_code = textwrap.dedent(
76+
"""
77+
import torch
78+
x = torch.tensor([[1., 2], [4, 5]])
79+
result = x.addmm(torch.tensor([[1., 2], [4, 5]]), torch.tensor([[1., 2], [4, 5]]))
80+
"""
81+
)
82+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)