Skip to content

Commit 5e03cc4

Browse files
committed
Fix
1 parent f4b789e commit 5e03cc4

7 files changed

+43
-32
lines changed

paconvert/api_mapping.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3321,6 +3321,7 @@
33213321
"min_input_args": 0,
33223322
"args_list": [
33233323
"dtype",
3324+
"dst_type",
33243325
"non_blocking"
33253326
]
33263327
},

paconvert/api_matcher.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1563,7 +1563,11 @@ def generate_code(self, kwargs):
15631563
if len(kwargs) == 0:
15641564
code = f"str({self.paddleClass}.dtype)"
15651565
else:
1566-
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
1566+
# For torch.nn.Module.type, which is handled via the torch.Tensor.type matcher
1567+
if "dst_type" in kwargs:
1568+
code = f"{self.paddleClass}.astype({kwargs['dst_type']})"
1569+
else:
1570+
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
15671571
return code
15681572

15691573

tests/test_Tensor_signbit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_case_1():
2323
pytorch_code = textwrap.dedent(
2424
"""
2525
import torch
26-
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
26+
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
2727
result = x.signbit()
2828
"""
2929
)
@@ -34,7 +34,7 @@ def test_case_2():
3434
pytorch_code = textwrap.dedent(
3535
"""
3636
import torch
37-
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float64')
37+
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float64)
3838
result = x.signbit()
3939
"""
4040
)

tests/test_nn_Module_type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def test_case_1():
3333
obj.run(pytorch_code, ["result"])
3434

3535

36-
# Will match torch.nn.Module, so the named parameter "dst_type" cannot be resolved.
37-
def _test_case_2():
36+
# Will match torch.Tensor.type to resolve the "dst_type" parameter.
37+
def test_case_2():
3838
pytorch_code = textwrap.dedent(
3939
"""
4040
import torch

tests/test_optim_lr_scheduler_CosineAnnealingWarmRestarts.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
import textwrap
1616

1717
from apibase import APIBase
18-
from lr_scheduler_helper import generate_torch_code
18+
from lr_scheduler_helper import generate_lr_scheduler_test_code
1919

2020
obj = APIBase("torch.optim.lr_scheduler.CosineAnnealingWarmRestarts")
2121

2222

2323
def test_case_1():
2424
pytorch_code = textwrap.dedent(
25-
generate_torch_code(
25+
generate_lr_scheduler_test_code(
2626
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10)"
2727
)
2828
)
@@ -31,7 +31,7 @@ def test_case_1():
3131

3232
def test_case_2():
3333
pytorch_code = textwrap.dedent(
34-
generate_torch_code(
34+
generate_lr_scheduler_test_code(
3535
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, T_0=10)"
3636
)
3737
)
@@ -40,7 +40,7 @@ def test_case_2():
4040

4141
def test_case_3():
4242
pytorch_code = textwrap.dedent(
43-
generate_torch_code(
43+
generate_lr_scheduler_test_code(
4444
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10)"
4545
)
4646
)
@@ -49,7 +49,7 @@ def test_case_3():
4949

5050
def test_case_4():
5151
pytorch_code = textwrap.dedent(
52-
generate_torch_code(
52+
generate_lr_scheduler_test_code(
5353
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=True)"
5454
)
5555
)
@@ -58,7 +58,7 @@ def test_case_4():
5858

5959
def test_case_5():
6060
pytorch_code = textwrap.dedent(
61-
generate_torch_code(
61+
generate_lr_scheduler_test_code(
6262
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.05, verbose=True)"
6363
)
6464
)
@@ -67,7 +67,7 @@ def test_case_5():
6767

6868
def test_case_6():
6969
pytorch_code = textwrap.dedent(
70-
generate_torch_code(
70+
generate_lr_scheduler_test_code(
7171
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10, 1, 1.0, -1, False)"
7272
)
7373
)
@@ -79,7 +79,7 @@ def test_case_6():
7979
# paddle result has diff with pytorch result
8080
def test_case_7():
8181
pytorch_code = textwrap.dedent(
82-
generate_torch_code(
82+
generate_lr_scheduler_test_code(
8383
[
8484
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=False)",
8585
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=scheduler_1.last_epoch, verbose=False)",

tests/test_optim_lr_scheduler_LinearLR.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,23 @@
1515
import textwrap
1616

1717
from apibase import APIBase
18-
from lr_scheduler_helper import generate_torch_code
18+
from lr_scheduler_helper import generate_lr_scheduler_test_code
1919

2020
obj = APIBase("torch.optim.lr_scheduler.LinearLR")
2121

2222

2323
def test_case_1():
2424
pytorch_code = textwrap.dedent(
25-
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)")
25+
generate_lr_scheduler_test_code(
26+
"torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)"
27+
)
2628
)
2729
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)
2830

2931

3032
def test_case_2():
3133
pytorch_code = textwrap.dedent(
32-
generate_torch_code(
34+
generate_lr_scheduler_test_code(
3335
"torch.optim.lr_scheduler.LinearLR(sgd, start_factor=0.05, end_factor=1.0)"
3436
)
3537
)
@@ -38,21 +40,25 @@ def test_case_2():
3840

3941
def test_case_3():
4042
pytorch_code = textwrap.dedent(
41-
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)")
43+
generate_lr_scheduler_test_code(
44+
"torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)"
45+
)
4246
)
4347
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)
4448

4549

4650
def test_case_4():
4751
pytorch_code = textwrap.dedent(
48-
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)")
52+
generate_lr_scheduler_test_code(
53+
"torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)"
54+
)
4955
)
5056
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)
5157

5258

5359
def test_case_5():
5460
pytorch_code = textwrap.dedent(
55-
generate_torch_code(
61+
generate_lr_scheduler_test_code(
5662
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3)"
5763
)
5864
)
@@ -61,7 +67,7 @@ def test_case_5():
6167

6268
def test_case_6():
6369
pytorch_code = textwrap.dedent(
64-
generate_torch_code(
70+
generate_lr_scheduler_test_code(
6571
"torch.optim.lr_scheduler.LinearLR(start_factor=0.05, end_factor=1.0, total_iters=3, optimizer=sgd)"
6672
)
6773
)
@@ -70,7 +76,7 @@ def test_case_6():
7076

7177
def test_case_7():
7278
pytorch_code = textwrap.dedent(
73-
generate_torch_code(
79+
generate_lr_scheduler_test_code(
7480
"torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1.0, 3, -1, False)"
7581
)
7682
)
@@ -79,7 +85,7 @@ def test_case_7():
7985

8086
def test_case_8():
8187
pytorch_code = textwrap.dedent(
82-
generate_torch_code(
88+
generate_lr_scheduler_test_code(
8389
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)"
8490
)
8591
)
@@ -88,7 +94,7 @@ def test_case_8():
8894

8995
def test_case_9():
9096
pytorch_code = textwrap.dedent(
91-
generate_torch_code(
97+
generate_lr_scheduler_test_code(
9298
[
9399
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)",
94100
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=scheduler_1.last_epoch, verbose=False)",
@@ -100,6 +106,6 @@ def test_case_9():
100106

101107
def test_case_10():
102108
pytorch_code = textwrap.dedent(
103-
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd)")
109+
generate_lr_scheduler_test_code("torch.optim.lr_scheduler.LinearLR(sgd)")
104110
)
105111
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)

tests/test_signbit.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_case_1():
2323
pytorch_code = textwrap.dedent(
2424
"""
2525
import torch
26-
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
26+
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
2727
result = torch.signbit(x)
2828
"""
2929
)
@@ -34,7 +34,7 @@ def test_case_2():
3434
pytorch_code = textwrap.dedent(
3535
"""
3636
import torch
37-
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
37+
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
3838
result = torch.signbit(input=x)
3939
"""
4040
)
@@ -45,8 +45,8 @@ def test_case_3():
4545
pytorch_code = textwrap.dedent(
4646
"""
4747
import torch
48-
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
49-
out = torch.tensor([])
48+
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
49+
out = torch.tensor([], dtype=torch.bool)
5050
result = torch.signbit(out=out, input=x)
5151
"""
5252
)
@@ -57,8 +57,8 @@ def test_case_4():
5757
pytorch_code = textwrap.dedent(
5858
"""
5959
import torch
60-
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
61-
out = torch.tensor([])
60+
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
61+
out = torch.tensor([], dtype=torch.bool)
6262
result = torch.signbit(input=x, out=out)
6363
"""
6464
)
@@ -69,8 +69,8 @@ def test_case_5():
6969
pytorch_code = textwrap.dedent(
7070
"""
7171
import torch
72-
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
73-
out = torch.tensor([])
72+
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
73+
out = torch.tensor([], dtype=torch.bool)
7474
result = torch.signbit(x, out=out)
7575
"""
7676
)

0 commit comments

Comments
 (0)