Skip to content

Commit 5b8763c

Browse files
authored
[SOT] Add MIN_GRAPH_SIZE=10 test in dy2st tests (#59191)
1 parent 2ba0387 commit 5b8763c

6 files changed

+56
-10
lines changed

test/dygraph_to_static/dygraph_to_static_utils_new.py

+41-6
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
from paddle import set_flags, static
2828
from paddle.base import core
2929
from paddle.jit.api import sot_mode_guard
30+
from paddle.jit.sot.opcode_translator.executor.executor_cache import (
31+
OpcodeExecutorCache,
32+
)
33+
from paddle.jit.sot.utils.envs import min_graph_size_guard
3034

3135
"""
3236
# Usage:
@@ -54,6 +58,8 @@ def test_case1(self):
5458
class ToStaticMode(Flag):
5559
AST = auto()
5660
SOT = auto()
61+
# SOT with MIN_GRAPH_SIZE=10, we only test SOT_MGS10 + LEGACY_IR to avoid regression
62+
SOT_MGS10 = auto()
5763

5864
def lower_case_name(self):
5965
return self.name.lower()
@@ -70,13 +76,15 @@ def lower_case_name(self):
7076
return self.name.lower()
7177

7278

73-
DEFAULT_TO_STATIC_MODE = ToStaticMode.AST | ToStaticMode.SOT
79+
DEFAULT_TO_STATIC_MODE = (
80+
ToStaticMode.AST | ToStaticMode.SOT | ToStaticMode.SOT_MGS10
81+
)
7482
DEFAULT_IR_MODE = IrMode.LEGACY_IR
7583

7684

7785
def to_legacy_ast_test(fn):
7886
"""
79-
convert run fall_back to ast
87+
convert run AST
8088
"""
8189

8290
@wraps(fn)
@@ -90,14 +98,34 @@ def impl(*args, **kwargs):
9098

9199
def to_sot_test(fn):
92100
"""
93-
convert run fall_back to ast
101+
convert run SOT
94102
"""
95103

96104
@wraps(fn)
97105
def impl(*args, **kwargs):
98106
logger.info("[SOT] running SOT")
107+
108+
OpcodeExecutorCache().clear()
99109
with sot_mode_guard(True):
100-
fn(*args, **kwargs)
110+
with min_graph_size_guard(0):
111+
fn(*args, **kwargs)
112+
113+
return impl
114+
115+
116+
def to_sot_mgs10_test(fn):
117+
"""
118+
convert run SOT and MIN_GRAPH_SIZE=10
119+
"""
120+
121+
@wraps(fn)
122+
def impl(*args, **kwargs):
123+
logger.info("[SOT_MGS10] running SOT")
124+
125+
OpcodeExecutorCache().clear()
126+
with sot_mode_guard(True):
127+
with min_graph_size_guard(10):
128+
fn(*args, **kwargs)
101129

102130
return impl
103131

@@ -148,8 +176,9 @@ def impl(*args, **kwargs):
148176
# Metaclass and BaseClass
149177
class Dy2StTestMeta(type):
150178
TO_STATIC_HANDLER_MAP = {
151-
ToStaticMode.SOT: to_sot_test,
152179
ToStaticMode.AST: to_legacy_ast_test,
180+
ToStaticMode.SOT: to_sot_test,
181+
ToStaticMode.SOT_MGS10: to_sot_mgs10_test,
153182
}
154183

155184
IR_HANDLER_MAP = {
@@ -204,6 +233,12 @@ def __new__(cls, name, bases, attrs):
204233
)
205234
# Generate all test cases
206235
for to_static_mode, ir_mode in to_static_with_ir_modes:
236+
if (
237+
to_static_mode == ToStaticMode.SOT_MGS10
238+
and ir_mode != IrMode.LEGACY_IR
239+
):
240+
# SOT_MGS10 only test with LEGACY_IR
241+
continue
207242
new_attrs[
208243
Dy2StTestMeta.test_case_name(
209244
fn_name, to_static_mode, ir_mode
@@ -262,7 +297,7 @@ def test_ast_only(fn):
262297

263298

264299
def test_sot_only(fn):
265-
fn = set_to_static_mode(ToStaticMode.SOT)(fn)
300+
fn = set_to_static_mode(ToStaticMode.SOT | ToStaticMode.SOT_MGS10)(fn)
266301
return fn
267302

268303

test/dygraph_to_static/test_gradname_parse.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
import numpy as np
1919
from dygraph_to_static_utils_new import (
2020
Dy2StTestBase,
21-
test_ast_only,
22-
test_pir_api_only,
21+
test_legacy_and_pir_api,
2322
)
2423

2524
import paddle
@@ -86,8 +85,7 @@ def setUp(self):
8685
self.dy2st_input = (x2,)
8786
self.dy2st_grad_input = (x2,)
8887

89-
@test_ast_only
90-
@test_pir_api_only
88+
@test_legacy_and_pir_api
9189
def test_run(self):
9290
try:
9391
dy_out = self.func(*self.dy_input)

test/dygraph_to_static/test_inplace_assign.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def func(x):
5454
@test_legacy_and_pir
5555
def test_case2(self):
5656
def func(a, x):
57+
x = 2 * x
5758
x[:] = a * 2.0
5859
return x
5960

test/dygraph_to_static/test_param_guard.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import unittest
1617

1718
import numpy as np

test/dygraph_to_static/test_seq2seq.py

+6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
import numpy as np
2121
from dygraph_to_static_utils_new import (
2222
Dy2StTestBase,
23+
IrMode,
24+
ToStaticMode,
25+
disable_test_case,
2326
)
2427
from seq2seq_dygraph_model import AttentionModel, BaseModel
2528
from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter
@@ -236,10 +239,13 @@ def _test_predict(self, attn_model=False):
236239
msg=f"\npred_dygraph = {pred_dygraph} \npred_static = {pred_static}",
237240
)
238241

242+
# Disable duplicated test case to avoid timeout
243+
@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
239244
def test_base_model(self):
240245
self._test_train(attn_model=False)
241246
self._test_predict(attn_model=False)
242247

248+
@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
243249
def test_attn_model(self):
244250
self._test_train(attn_model=True)
245251
# TODO(liym27): add predict

test/dygraph_to_static/test_to_tensor.py

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import numpy
1818
from dygraph_to_static_utils_new import (
1919
Dy2StTestBase,
20+
IrMode,
21+
ToStaticMode,
22+
disable_test_case,
2023
test_legacy_and_pir_exe_and_pir_api,
2124
test_legacy_only,
2225
test_pir_api_only,
@@ -165,7 +168,9 @@ def test_to_tensor_default_dtype(self):
165168
self.assertTrue(a.stop_gradient == b.stop_gradient)
166169
self.assertTrue(a.place._equals(b.place))
167170

171+
# MIN_GRAPH_SIZE=10 will cause fallback and raise error in dygraph
168172
@test_legacy_and_pir_exe_and_pir_api
173+
@disable_test_case((ToStaticMode.SOT_MGS10, IrMode.LEGACY_IR))
169174
def test_to_tensor_err_log(self):
170175
paddle.disable_static()
171176
x = paddle.to_tensor([3])

0 commit comments

Comments
 (0)