Skip to content

Commit 8fa1045

Browse files
authored
Add Test cases-6-22 (#127)
* Add test cases 6-22 * fix bug of torch.Tensor.broadcast_to * Fix bugs * Fix bugs
1 parent 8863e7b commit 8fa1045

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2708
-10
lines changed

paconvert/api_mapping.json

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5437,7 +5437,7 @@
54375437
},
54385438
"torch.Tensor.addmm_": {},
54395439
"torch.Tensor.allclose": {
5440-
"Matcher": "GenericMatcher",
5440+
"Matcher": "AllcloseMatcher",
54415441
"paddle_api": "paddle.Tensor.allclose",
54425442
"args_list": [
54435443
"other",
@@ -5534,7 +5534,7 @@
55345534
},
55355535
"torch.Tensor.atan2_": {},
55365536
"torch.Tensor.all": {
5537-
"Matcher": "GenericMatcher",
5537+
"Matcher": "TensorToBoolMatcher",
55385538
"paddle_api": "paddle.Tensor.all",
55395539
"args_list": [
55405540
"dim",
@@ -5545,7 +5545,7 @@
55455545
}
55465546
},
55475547
"torch.Tensor.any": {
5548-
"Matcher": "GenericMatcher",
5548+
"Matcher": "TensorToBoolMatcher",
55495549
"paddle_api": "paddle.Tensor.any",
55505550
"args_list": [
55515551
"dim",
@@ -5579,7 +5579,7 @@
55795579
},
55805580
"torch.Tensor.bitwise_or": {
55815581
"Matcher": "GenericMatcher",
5582-
"paddle_api": "paddle.Tensor.bitwise_and",
5582+
"paddle_api": "paddle.Tensor.bitwise_or",
55835583
"args_list": [
55845584
"other"
55855585
],
@@ -5589,7 +5589,7 @@
55895589
},
55905590
"torch.Tensor.bitwise_xor": {
55915591
"Matcher": "GenericMatcher",
5592-
"paddle_api": "paddle.Tensor.bitwise_and",
5592+
"paddle_api": "paddle.Tensor.bitwise_xor",
55935593
"args_list": [
55945594
"other"
55955595
],
@@ -5601,18 +5601,21 @@
56015601
"Matcher": "GenericMatcher",
56025602
"paddle_api": "paddle.Tensor.bmm",
56035603
"args_list": [
5604-
"batch2"
5604+
"mat2"
56055605
],
56065606
"kwargs_change": {
5607-
"batch2": "y"
5607+
"mat2": "y"
56085608
}
56095609
},
56105610
"torch.Tensor.broadcast_to": {
56115611
"Matcher": "GenericMatcher",
56125612
"paddle_api": "paddle.Tensor.broadcast_to",
56135613
"args_list": [
5614-
"shape"
5615-
]
5614+
"size"
5615+
],
5616+
"kwargs_change": {
5617+
"size": "shape"
5618+
}
56165619
},
56175620
"torch.Tensor.ceil": {
56185621
"Matcher": "TensorUnchangeMatcher"
@@ -5690,7 +5693,7 @@
56905693
},
56915694
"torch.Tensor.arccosh": {
56925695
"Matcher": "GenericMatcher",
5693-
"paddle_api": "paddle.Tensor.arccosh"
5696+
"paddle_api": "paddle.Tensor.acosh"
56945697
},
56955698
"torch.Tensor.cross": {
56965699
"Matcher": "GenericMatcher",

paconvert/api_matcher.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3666,3 +3666,17 @@ def generate_code(self, kwargs):
36663666
self.kwargs_to_str(kwargs),
36673667
)
36683668
return code.strip("\n")
3669+
3670+
3671+
class TensorToBoolMatcher(BaseMatcher):
3672+
def generate_code(self, kwargs):
3673+
if "dim" in kwargs:
3674+
kwargs["axis"] = kwargs.pop("dim").strip("\n")
3675+
3676+
paddle_api = self.get_paddle_api()
3677+
paddle_api_name = paddle_api[paddle_api.rfind(".") :]
3678+
code = "{}({})".format(
3679+
self.paddleClass + ".astype('bool')" + paddle_api_name,
3680+
self.kwargs_to_str(kwargs),
3681+
)
3682+
return code

tests/test_Tensor_absolute.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import textwrap
17+
18+
from apibase import APIBase
19+
20+
obj = APIBase("torch.absolute")
21+
22+
23+
def test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
import torch
27+
a = torch.tensor([[-4, 9], [-23, 2]])
28+
result = a.absolute()
29+
"""
30+
)
31+
obj.run(pytorch_code, ["result"])
32+
33+
34+
def test_case_2():
35+
pytorch_code = textwrap.dedent(
36+
"""
37+
import torch
38+
result = torch.tensor([[-4, 9], [-23, 2]]).absolute()
39+
"""
40+
)
41+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_acos.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import textwrap
17+
18+
from apibase import APIBase
19+
20+
obj = APIBase("torch.acos")
21+
22+
23+
def test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
import torch
27+
a = torch.tensor([[ 0.3348, -0.5889, 0.2005, -0.1584], [ 0.3348, -0.5889, 0.2005, -0.1584]])
28+
result = a.acos()
29+
"""
30+
)
31+
obj.run(pytorch_code, ["result"])
32+
33+
34+
def test_case_2():
35+
pytorch_code = textwrap.dedent(
36+
"""
37+
import torch
38+
result = torch.tensor([[ 0.3348, -0.5889, 0.2005, -0.1584]]).acos()
39+
"""
40+
)
41+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_acosh.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
17+
import textwrap
18+
19+
from apibase import APIBase
20+
21+
obj = APIBase("torch.Tensor.acosh")
22+
23+
24+
def test_case_1():
25+
pytorch_code = textwrap.dedent(
26+
"""
27+
import torch
28+
result = torch.tensor([1.3192, 1.9915, 1.9674, 1.7151]).acosh()
29+
"""
30+
)
31+
obj.run(pytorch_code, ["result"])
32+
33+
34+
def test_case_2():
35+
pytorch_code = textwrap.dedent(
36+
"""
37+
import torch
38+
a = torch.tensor([1.3192, 1.9915, 1.9674, 1.7151])
39+
result = a.acosh()
40+
"""
41+
)
42+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_addmm.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.Tensor.addmm")
20+
21+
22+
# The paddle input does not support integer type
23+
def _test_case_1():
24+
pytorch_code = textwrap.dedent(
25+
"""
26+
import torch
27+
x = torch.tensor([[1, 2], [4, 5]])
28+
mat1 = torch.tensor([[1, 2], [4, 5]])
29+
mat2 = torch.tensor([[1, 2], [4, 5]])
30+
result = x.addmm(mat1, mat2)
31+
"""
32+
)
33+
obj.run(pytorch_code, ["result"])
34+
35+
36+
def test_case_2():
37+
pytorch_code = textwrap.dedent(
38+
"""
39+
import torch
40+
x = torch.tensor([[1., 2], [4, 5]])
41+
mat1 = torch.tensor([[1., 2], [4, 5]])
42+
mat2 = torch.tensor([[1., 2], [4, 5]])
43+
result = x.addmm(mat1, mat2, beta=0.6, alpha=0.7)
44+
"""
45+
)
46+
obj.run(pytorch_code, ["result"])
47+
48+
49+
def test_case_3():
50+
pytorch_code = textwrap.dedent(
51+
"""
52+
import torch
53+
x = torch.tensor([[1., 2], [4, 5]])
54+
mat1 = torch.tensor([[1., 2], [4, 5]])
55+
mat2 = torch.tensor([[1., 2], [4, 5]])
56+
result = x.addmm(mat1=mat1, mat2=mat2, beta=0.6, alpha=0.7)
57+
"""
58+
)
59+
obj.run(pytorch_code, ["result"])
60+
61+
62+
def test_case_4():
63+
pytorch_code = textwrap.dedent(
64+
"""
65+
import torch
66+
x = torch.tensor([[1., 2], [4, 5]])
67+
mat1 = torch.tensor([[1., 2], [4, 5]])
68+
result = x.addmm(mat1, torch.tensor([[1., 2], [4, 5]]), beta=0.6, alpha=0.7)
69+
"""
70+
)
71+
obj.run(pytorch_code, ["result"])
72+
73+
74+
def test_case_5():
75+
pytorch_code = textwrap.dedent(
76+
"""
77+
import torch
78+
x = torch.tensor([[1., 2], [4, 5]])
79+
result = x.addmm(torch.tensor([[1., 2], [4, 5]]), torch.tensor([[1., 2], [4, 5]]))
80+
"""
81+
)
82+
obj.run(pytorch_code, ["result"])

tests/test_Tensor_all.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import textwrap
16+
17+
from apibase import APIBase
18+
19+
obj = APIBase("torch.Tensor.all")
20+
21+
22+
def test_case_1():
23+
pytorch_code = textwrap.dedent(
24+
"""
25+
import torch
26+
a = torch.rand(1, 2).bool()
27+
result = a.all()
28+
"""
29+
)
30+
obj.run(pytorch_code, ["result"])
31+
32+
33+
def test_case_2():
34+
pytorch_code = textwrap.dedent(
35+
"""
36+
import torch
37+
a = torch.rand(3, 4)
38+
result = a.all()
39+
"""
40+
)
41+
obj.run(pytorch_code, ["result"])
42+
43+
44+
def test_case_3():
45+
pytorch_code = textwrap.dedent(
46+
"""
47+
import torch
48+
a = torch.rand(4, 3)
49+
result = a.all(1)
50+
"""
51+
)
52+
obj.run(pytorch_code, ["result"])
53+
54+
55+
def test_case_4():
56+
pytorch_code = textwrap.dedent(
57+
"""
58+
import torch
59+
a = torch.rand(4, 3)
60+
result = a.all(1, True)
61+
"""
62+
)
63+
obj.run(pytorch_code, ["result"])
64+
65+
66+
def test_case_5():
67+
pytorch_code = textwrap.dedent(
68+
"""
69+
import torch
70+
a = torch.rand(4, 3)
71+
result = a.all(dim=0, keepdim=False)
72+
"""
73+
)
74+
obj.run(pytorch_code, ["result"])

0 commit comments

Comments
 (0)