Skip to content

Commit 0a85d41

Browse files
authored
[Dy2St] Reopen AST+legacy IR uts (#72207)
1 parent 85e1c9a commit 0a85d41

10 files changed

+44
-34
lines changed

test/dygraph_to_static/dygraph_to_static_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def lower_case_name(self):
118118
DEFAULT_IR_MODE = IrMode.PT | IrMode.PIR
119119
DEFAULT_BACKEND_MODE = BackendMode.PHI | BackendMode.CINN
120120
VALID_MODES = [
121-
# For `.pd_model` export, we still need test AST+PT
121+
# For `.pd_model` export, we still need test AST+PT / AST+LEGACY_IR
122+
(ToStaticMode.AST, IrMode.LEGACY_IR, BackendMode.PHI),
122123
(ToStaticMode.AST, IrMode.PT, BackendMode.PHI),
123124
(ToStaticMode.AST, IrMode.PIR, BackendMode.PHI),
124125
(ToStaticMode.SOT, IrMode.PIR, BackendMode.PHI),

test/dygraph_to_static/test_ast_util.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
Dy2StTestBase,
2222
static_guard,
2323
test_ast_only,
24-
test_legacy_and_pir,
24+
test_pir_only,
2525
)
2626
from ifelse_simple_func import (
2727
dyfunc_with_if_else,
@@ -49,7 +49,7 @@ def _ast2func(self, func):
4949
return transformed_func
5050

5151
@test_ast_only
52-
@test_legacy_and_pir
52+
@test_pir_only
5353
def test_ast2func(self):
5454
def func(x, y):
5555
return x + y
@@ -58,7 +58,7 @@ def func(x, y):
5858
self.assertEqual(func(x, y), self._ast2func(func)(x, y))
5959

6060
@test_ast_only
61-
@test_legacy_and_pir
61+
@test_pir_only
6262
def test_ast2func_dygraph(self):
6363
funcs = [dyfunc_with_if_else, dyfunc_with_if_else2, nested_if_else]
6464
x_data = np.random.random([10, 16]).astype('float32')
@@ -69,7 +69,7 @@ def test_ast2func_dygraph(self):
6969
self.assertTrue((true_ret == test_ret).all())
7070

7171
@test_ast_only
72-
@test_legacy_and_pir
72+
@test_pir_only
7373
def test_ast2func_static(self):
7474
def func(x):
7575
y = F.relu(x)
@@ -88,7 +88,7 @@ def func(x):
8888
self.assertTrue((ret[0] == ret[1]).all())
8989

9090
@test_ast_only
91-
@test_legacy_and_pir
91+
@test_pir_only
9292
def test_ast2func_error(self):
9393
with self.assertRaises(Exception) as e:
9494
self.assertRaises(TypeError, ast_to_func("x = a + b", 'foo'))

test/dygraph_to_static/test_dygraph_to_static_utils.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313
# limitations under the License.
1414

1515
import unittest
16+
from itertools import product
1617

1718
from dygraph_to_static_utils import (
19+
DEFAULT_BACKEND_MODE,
20+
DEFAULT_IR_MODE,
21+
DEFAULT_TO_STATIC_MODE,
1822
VALID_MODES,
1923
BackendMode,
2024
Dy2StTestBase,
@@ -28,6 +32,18 @@
2832
set_to_static_mode,
2933
)
3034

35+
ALL_MODES = list(product(ToStaticMode, IrMode, BackendMode))
36+
DEFAULT_MODES = [
37+
(to_static_mode, ir_mode, backend_mode)
38+
for (to_static_mode, ir_mode, backend_mode) in ALL_MODES
39+
if (
40+
(to_static_mode, ir_mode, backend_mode) in VALID_MODES
41+
and to_static_mode & DEFAULT_TO_STATIC_MODE
42+
and ir_mode & DEFAULT_IR_MODE
43+
and backend_mode & DEFAULT_BACKEND_MODE
44+
)
45+
]
46+
3147

3248
class CheckTestCaseExistsMixin:
3349
def assert_hasattr(self, obj: object, attr: str):
@@ -89,14 +105,14 @@ def test_check_test_case_basic(self):
89105
test_case = TestCaseBasic()
90106
case_name = "test_basic"
91107
self.assert_not_hasattr(test_case, case_name)
92-
for mode_tuple in VALID_MODES:
108+
for mode_tuple in DEFAULT_MODES:
93109
self.check_test_case_exists(test_case, case_name, mode_tuple)
94110

95111
def test_check_test_case_disable_test_case(self):
96112
test_case = TestCaseDisableTestCase()
97113
case_name = "test_disable_one"
98114
self.assert_not_hasattr(test_case, case_name)
99-
for mode_tuple in VALID_MODES:
115+
for mode_tuple in DEFAULT_MODES:
100116
if mode_tuple == (ToStaticMode.SOT, IrMode.PIR, BackendMode.CINN):
101117
self.check_test_case_not_exists(
102118
test_case, case_name, mode_tuple
@@ -106,7 +122,7 @@ def test_check_test_case_disable_test_case(self):
106122

107123
case_name = "test_disable_multiple"
108124
self.assert_not_hasattr(test_case, case_name)
109-
for mode_tuple in VALID_MODES:
125+
for mode_tuple in DEFAULT_MODES:
110126
if mode_tuple in [
111127
(ToStaticMode.SOT, IrMode.PIR, BackendMode.CINN),
112128
(ToStaticMode.SOT, IrMode.PIR, BackendMode.PHI),
@@ -120,7 +136,7 @@ def test_check_test_case_disable_test_case(self):
120136

121137
case_name = "test_disable_multiple_with_or"
122138
self.assert_not_hasattr(test_case, case_name)
123-
for mode_tuple in VALID_MODES:
139+
for mode_tuple in DEFAULT_MODES:
124140
if mode_tuple in [
125141
(ToStaticMode.SOT, IrMode.PIR, BackendMode.CINN),
126142
(ToStaticMode.SOT, IrMode.PIR, BackendMode.PHI),
@@ -135,7 +151,7 @@ def test_check_test_case_set_mode(self):
135151
test_case = TestCaseSetMode()
136152
case_name = "test_set_to_static_mode"
137153
self.assert_not_hasattr(test_case, case_name)
138-
for mode_tuple in VALID_MODES:
154+
for mode_tuple in DEFAULT_MODES:
139155
to_static_mode, _, _ = mode_tuple
140156
if to_static_mode == ToStaticMode.SOT:
141157
self.check_test_case_exists(test_case, case_name, mode_tuple)
@@ -146,7 +162,7 @@ def test_check_test_case_set_mode(self):
146162

147163
case_name = "test_set_ir_mode"
148164
self.assert_not_hasattr(test_case, case_name)
149-
for mode_tuple in VALID_MODES:
165+
for mode_tuple in DEFAULT_MODES:
150166
_, ir_mode, _ = mode_tuple
151167
if ir_mode == IrMode.PIR:
152168
self.check_test_case_exists(test_case, case_name, mode_tuple)
@@ -157,7 +173,7 @@ def test_check_test_case_set_mode(self):
157173

158174
case_name = "test_set_backend_mode"
159175
self.assert_not_hasattr(test_case, case_name)
160-
for mode_tuple in VALID_MODES:
176+
for mode_tuple in DEFAULT_MODES:
161177
_, _, backend_mode = mode_tuple
162178
if backend_mode == BackendMode.CINN:
163179
self.check_test_case_exists(test_case, case_name, mode_tuple)
@@ -168,7 +184,7 @@ def test_check_test_case_set_mode(self):
168184

169185
case_name = "test_set_all"
170186
self.assert_not_hasattr(test_case, case_name)
171-
for mode_tuple in VALID_MODES:
187+
for mode_tuple in DEFAULT_MODES:
172188
if mode_tuple == (ToStaticMode.SOT, IrMode.PIR, BackendMode.CINN):
173189
self.check_test_case_exists(test_case, case_name, mode_tuple)
174190
else:

test/dygraph_to_static/test_ifelse.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
disable_test_case,
2424
enable_to_static_guard,
2525
test_ast_only,
26-
test_legacy_and_pir,
2726
test_phi_only,
2827
test_pir_only,
2928
)
@@ -553,7 +552,7 @@ def forward(self, a, b, c):
553552

554553

555554
class TestDy2StIfElseBackward(Dy2StTestBase):
556-
@test_legacy_and_pir
555+
@test_pir_only
557556
def test_run_backward(self):
558557
a = paddle.randn((4, 3), dtype='float32')
559558
a.stop_gradient = False

test/dygraph_to_static/test_incubate_jit_inference.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from dygraph_to_static_utils import (
1919
Dy2StTestBase,
2020
test_ast_only,
21-
test_legacy_and_pt,
22-
test_legacy_and_pt_and_pir,
2321
)
2422

2523
import paddle
@@ -63,7 +61,6 @@ def forward(self, x_list, bool_value):
6361

6462
class TestToStaticInfenrenceModel(Dy2StTestBase):
6563
@test_ast_only
66-
@test_legacy_and_pt_and_pir
6764
def test_dygraph_static_same_result(self):
6865
hidd = 1024
6966
batch = 4096
@@ -78,7 +75,6 @@ def test_dygraph_static_same_result(self):
7875

7976
class TestToStaticInfenrenceTensorRTModel(Dy2StTestBase):
8077
@test_ast_only
81-
@test_legacy_and_pt
8278
def test_dygraph_static_same_result(self):
8379
if paddle_infer.get_trt_compile_version()[0] == 0:
8480
return
@@ -95,7 +91,6 @@ def test_dygraph_static_same_result(self):
9591

9692
class TestToStaticInfenrenceFunc(Dy2StTestBase):
9793
@test_ast_only
98-
@test_legacy_and_pt_and_pir
9994
def test_dygraph_static_same_result(self):
10095
hidd = 1024
10196
batch = 4096
@@ -118,7 +113,6 @@ def test_dygraph_static_same_result(self):
118113

119114
class TestToStaticInputListModel(Dy2StTestBase):
120115
@test_ast_only
121-
@test_legacy_and_pt
122116
def test_dygraph_static_same_result(self):
123117
hidd = 1024
124118
batch = 4096

test/dygraph_to_static/test_len.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
Dy2StTestBase,
2020
static_guard,
2121
test_ast_only,
22-
test_legacy_and_pt,
2322
test_pir_only,
23+
test_pt_only,
2424
)
2525

2626
import paddle
@@ -165,7 +165,7 @@ def setUp(self):
165165
)
166166

167167
@test_ast_only
168-
@test_legacy_and_pt
168+
@test_pt_only
169169
def test_len_legacy(self):
170170
with static_guard():
171171
(

test/dygraph_to_static/test_partial_program.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from dygraph_to_static_utils import (
1919
Dy2StTestBase,
2020
test_ast_only,
21-
test_legacy_and_pt,
2221
test_pir_only,
22+
test_pt_only,
2323
)
2424
from test_fetch_feed import Linear
2525

@@ -132,7 +132,7 @@ def test_nest(self):
132132

133133
class TestWithTrainAndEval(Dy2StTestBase):
134134
@test_ast_only
135-
@test_legacy_and_pt
135+
@test_pt_only
136136
def test_legacy_ir_switch_eval_and_train(self):
137137
# TODO(cleanup-legacy-ir): Remove this test case
138138
linear_net = Linear()
@@ -196,7 +196,7 @@ def test_switch_eval_and_train(self):
196196

197197
class TestWithNoGrad(Dy2StTestBase):
198198
@test_ast_only
199-
@test_legacy_and_pt
199+
@test_pt_only
200200
def test_legacy_ir_with_no_grad(self):
201201
# TODO(cleanup-legacy-ir): Remove this test case
202202
linear_net = Linear()

test/dygraph_to_static/test_place.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717

1818
from dygraph_to_static_utils import (
1919
Dy2StTestBase,
20-
test_legacy_and_pt,
2120
test_pir_only,
21+
test_pt_only,
2222
)
2323

2424
import paddle
2525

2626

2727
class TestPlace(Dy2StTestBase):
28-
@test_legacy_and_pt
28+
@test_pt_only
2929
def test_place_legacy(self):
3030
# TODO(cleanup-legacy-ir): remove this test case
3131
paddle.enable_static()

test/dygraph_to_static/test_typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import unittest
1919

2020
import numpy as np
21-
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pir
21+
from dygraph_to_static_utils import Dy2StTestBase, test_pir_only
2222

2323
import paddle
2424

@@ -94,7 +94,7 @@ def run_dy(self):
9494
out, _ = self.net(self.x)
9595
return out
9696

97-
@test_legacy_and_pir
97+
@test_pir_only
9898
def test_type(self):
9999
self.net = self.build_net()
100100
out = self.run_dy()

test/dygraph_to_static/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
import types
1616
import unittest
1717

18-
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pir
18+
from dygraph_to_static_utils import Dy2StTestBase, test_pir_only
1919

2020
from paddle.jit.dy2static.transformers.utils import index_in_list
2121
from paddle.jit.dy2static.utils import is_paddle_func
2222

2323

2424
class TestIndexInList(Dy2StTestBase):
25-
@test_legacy_and_pir
25+
@test_pir_only
2626
def test_index_in_list(self):
2727
list_to_test = [1, 2, 3, 4, 5]
2828
self.assertEqual(index_in_list(list_to_test, 4), 3)
@@ -57,7 +57,7 @@ class TestIsPaddle(Dy2StTestBase):
5757
def fake_module(self):
5858
return types.ModuleType('paddlenlp')
5959

60-
@test_legacy_and_pir
60+
@test_pir_only
6161
def test_func(self):
6262
m = self.fake_module()
6363
self.assertFalse(is_paddle_func(m))

0 commit comments

Comments
 (0)