Skip to content

Commit 257f5d3

Browse files
authored
【PIR API adaptor No.257、265】Migrate gcd,lcm into pir (#59600)
1 parent d57954a commit 257f5d3

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

test/legacy_test/test_gcd.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
import numpy as np
1818

1919
import paddle
20-
from paddle import base
2120
from paddle.base import core
22-
23-
paddle.enable_static()
21+
from paddle.pir_utils import test_with_pir_api
2422

2523

2624
class TestGcdAPI(unittest.TestCase):
@@ -30,32 +28,31 @@ def setUp(self):
3028
self.x_shape = [1]
3129
self.y_shape = [1]
3230

31+
@test_with_pir_api
3332
def test_static_graph(self):
34-
startup_program = base.Program()
35-
train_program = base.Program()
36-
with base.program_guard(startup_program, train_program):
33+
if core.is_compiled_with_cuda():
34+
place = core.CUDAPlace(0)
35+
else:
36+
place = core.CPUPlace()
37+
with paddle.static.program_guard(
38+
paddle.static.Program(), paddle.static.Program()
39+
):
3740
x = paddle.static.data(
3841
name='input1', dtype='int32', shape=self.x_shape
3942
)
4043
y = paddle.static.data(
4144
name='input2', dtype='int32', shape=self.y_shape
4245
)
4346
out = paddle.gcd(x, y)
47+
out_ref = np.gcd(self.x_np, self.y_np)
4448

45-
place = (
46-
base.CUDAPlace(0)
47-
if core.is_compiled_with_cuda()
48-
else base.CPUPlace()
49-
)
50-
exe = base.Executor(place)
49+
exe = paddle.static.Executor(place)
5150
res = exe.run(
52-
base.default_main_program(),
51+
paddle.static.default_main_program(),
5352
feed={'input1': self.x_np, 'input2': self.y_np},
5453
fetch_list=[out],
5554
)
56-
self.assertTrue(
57-
(np.array(res[0]) == np.gcd(self.x_np, self.y_np)).all()
58-
)
55+
self.assertTrue((res[0] == out_ref).all())
5956

6057
def test_dygraph(self):
6158
paddle.disable_static()
@@ -99,3 +96,8 @@ def setUp(self):
9996
self.y_np = -20
10097
self.x_shape = []
10198
self.y_shape = []
99+
100+
101+
if __name__ == "__main__":
102+
paddle.enable_static()
103+
unittest.main()

test/legacy_test/test_lcm.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
import numpy as np
1818

1919
import paddle
20-
from paddle import base
2120
from paddle.base import core
22-
23-
paddle.enable_static()
21+
from paddle.pir_utils import test_with_pir_api
2422

2523

2624
class TestLcmAPI(unittest.TestCase):
@@ -30,32 +28,30 @@ def setUp(self):
3028
self.x_shape = []
3129
self.y_shape = []
3230

31+
@test_with_pir_api
3332
def test_static_graph(self):
34-
startup_program = base.Program()
35-
train_program = base.Program()
36-
with base.program_guard(startup_program, train_program):
33+
if core.is_compiled_with_cuda():
34+
place = core.CUDAPlace(0)
35+
else:
36+
place = core.CPUPlace()
37+
with paddle.static.program_guard(
38+
paddle.static.Program(), paddle.static.Program()
39+
):
3740
x1 = paddle.static.data(
3841
name='input1', dtype='int32', shape=self.x_shape
3942
)
4043
x2 = paddle.static.data(
4144
name='input2', dtype='int32', shape=self.y_shape
4245
)
4346
out = paddle.lcm(x1, x2)
44-
45-
place = (
46-
base.CUDAPlace(0)
47-
if core.is_compiled_with_cuda()
48-
else base.CPUPlace()
49-
)
50-
exe = base.Executor(place)
47+
out_ref = np.lcm(self.x_np, self.y_np)
48+
exe = paddle.static.Executor(place)
5149
res = exe.run(
52-
base.default_main_program(),
50+
paddle.static.default_main_program(),
5351
feed={'input1': self.x_np, 'input2': self.y_np},
5452
fetch_list=[out],
5553
)
56-
self.assertTrue(
57-
(np.array(res[0]) == np.lcm(self.x_np, self.y_np)).all()
58-
)
54+
self.assertTrue((res[0] == out_ref).all())
5955

6056
def test_dygraph(self):
6157
paddle.disable_static()
@@ -99,3 +95,8 @@ def setUp(self):
9995
self.y_np = -20
10096
self.x_shape = []
10197
self.y_shape = []
98+
99+
100+
if __name__ == "__main__":
101+
paddle.enable_static()
102+
unittest.main()

0 commit comments

Comments
 (0)