27
27
from paddle import set_flags , static
28
28
from paddle .base import core
29
29
from paddle .jit .api import sot_mode_guard
30
+ from paddle .jit .sot .opcode_translator .executor .executor_cache import (
31
+ OpcodeExecutorCache ,
32
+ )
33
+ from paddle .jit .sot .utils .envs import min_graph_size_guard
30
34
31
35
"""
32
36
# Usage:
@@ -54,6 +58,8 @@ def test_case1(self):
54
58
class ToStaticMode (Flag ):
55
59
AST = auto ()
56
60
SOT = auto ()
61
+ # SOT with MIN_GRAPH_SIZE=10, we only test SOT_MGS10 + LEGACY_IR to avoid regression
62
+ SOT_MGS10 = auto ()
57
63
58
64
def lower_case_name (self ):
59
65
return self .name .lower ()
@@ -70,13 +76,15 @@ def lower_case_name(self):
70
76
return self .name .lower ()
71
77
72
78
73
- DEFAULT_TO_STATIC_MODE = ToStaticMode .AST | ToStaticMode .SOT
79
+ DEFAULT_TO_STATIC_MODE = (
80
+ ToStaticMode .AST | ToStaticMode .SOT | ToStaticMode .SOT_MGS10
81
+ )
74
82
DEFAULT_IR_MODE = IrMode .LEGACY_IR
75
83
76
84
77
85
def to_legacy_ast_test (fn ):
78
86
"""
79
- convert run fall_back to ast
87
+ convert run AST
80
88
"""
81
89
82
90
@wraps (fn )
@@ -90,14 +98,34 @@ def impl(*args, **kwargs):
90
98
91
99
def to_sot_test (fn ):
92
100
"""
93
- convert run fall_back to ast
101
+ convert run SOT
94
102
"""
95
103
96
104
@wraps (fn )
97
105
def impl (* args , ** kwargs ):
98
106
logger .info ("[SOT] running SOT" )
107
+
108
+ OpcodeExecutorCache ().clear ()
99
109
with sot_mode_guard (True ):
100
- fn (* args , ** kwargs )
110
+ with min_graph_size_guard (0 ):
111
+ fn (* args , ** kwargs )
112
+
113
+ return impl
114
+
115
+
116
+ def to_sot_mgs10_test (fn ):
117
+ """
118
+ convert run SOT and MIN_GRAPH_SIZE=10
119
+ """
120
+
121
+ @wraps (fn )
122
+ def impl (* args , ** kwargs ):
123
+ logger .info ("[SOT_MGS10] running SOT" )
124
+
125
+ OpcodeExecutorCache ().clear ()
126
+ with sot_mode_guard (True ):
127
+ with min_graph_size_guard (10 ):
128
+ fn (* args , ** kwargs )
101
129
102
130
return impl
103
131
@@ -148,8 +176,9 @@ def impl(*args, **kwargs):
148
176
# Metaclass and BaseClass
149
177
class Dy2StTestMeta (type ):
150
178
TO_STATIC_HANDLER_MAP = {
151
- ToStaticMode .SOT : to_sot_test ,
152
179
ToStaticMode .AST : to_legacy_ast_test ,
180
+ ToStaticMode .SOT : to_sot_test ,
181
+ ToStaticMode .SOT_MGS10 : to_sot_mgs10_test ,
153
182
}
154
183
155
184
IR_HANDLER_MAP = {
@@ -204,6 +233,12 @@ def __new__(cls, name, bases, attrs):
204
233
)
205
234
# Generate all test cases
206
235
for to_static_mode , ir_mode in to_static_with_ir_modes :
236
+ if (
237
+ to_static_mode == ToStaticMode .SOT_MGS10
238
+ and ir_mode != IrMode .LEGACY_IR
239
+ ):
240
+ # SOT_MGS10 only test with LEGACY_IR
241
+ continue
207
242
new_attrs [
208
243
Dy2StTestMeta .test_case_name (
209
244
fn_name , to_static_mode , ir_mode
@@ -262,7 +297,7 @@ def test_ast_only(fn):
262
297
263
298
264
299
def test_sot_only (fn ):
265
- fn = set_to_static_mode (ToStaticMode .SOT )(fn )
300
+ fn = set_to_static_mode (ToStaticMode .SOT | ToStaticMode . SOT_MGS10 )(fn )
266
301
return fn
267
302
268
303
0 commit comments